In [1]:
import os
from datetime import datetime

# Hide TensorFlow's native INFO and WARNING log lines. This must happen
# *before* TensorFlow is imported: the C++ runtime reads the variable when
# the library is loaded, so setting it afterwards has no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, optimizers, datasets
# NOTE(review): `custom_gradient` comes from the private `tensorflow.python`
# namespace; it, `keras`, `datasets`, `datetime` and `version` look unused in
# this notebook — confirm before removing.
from tensorflow.python.ops import custom_gradient
from packaging import version

## Model definition
To use `tf.recompute_grad`, you have to manually split a large model into manageable blocks, each of which acts as a gradient checkpoint. The activations inside a block are not kept during the forward pass; they are recomputed during backpropagation, trading extra compute for lower memory. Below is a simple CNN split into two blocks.

In [2]:
def simple_block1(img_dim, n_channels):
    """Build the first checkpoint block: a 2x2 ReLU convolution with one filter.

    Args:
        img_dim: height and width of the square input images.
        n_channels: number of channels in the input images.

    Returns:
        A `tf.keras.Sequential` model mapping inputs of shape
        `(batch, img_dim, img_dim, n_channels)` to feature maps of shape
        `(batch, img_dim, img_dim, 1)` ('same' padding, stride 1).
    """
    return tf.keras.Sequential([
        # Declare the input with an explicit Input layer (as simple_block2
        # does) instead of passing `input_shape=` to a regular layer, which
        # newer Keras versions deprecate. The model is built on construction,
        # so its trainable_variables exist before the first call.
        layers.Input(shape=(img_dim, img_dim, n_channels)),
        layers.Conv2D(1, 2, padding='same', activation=tf.nn.relu),
        # A (1, 1) pool with (1, 1) strides leaves the tensor unchanged.
        layers.MaxPooling2D((1, 1), padding='same'),
    ])

def simple_block2(img_dim=2, in_channels=1, n_classes=2):
    """Build the second checkpoint block: convolution, flatten and a logits layer.

    The input shape used to be hard-coded to (2, 2, 1), which only matches the
    output of ``simple_block1`` when ``img_dim == 2``. The defaults preserve
    that behaviour, so ``simple_block2()`` builds exactly the same model as
    before.

    Args:
        img_dim: height and width of the incoming feature maps.
        in_channels: number of channels in the incoming feature maps (the
            number of filters produced by the previous block).
        n_classes: number of output logits.

    Returns:
        A `tf.keras.Sequential` model mapping inputs of shape
        `(batch, img_dim, img_dim, in_channels)` to unnormalised logits of
        shape `(batch, n_classes)`.
    """
    return tf.keras.Sequential([
        layers.Input(shape=(img_dim, img_dim, in_channels)),
        layers.Conv2D(1, 2, padding='same', activation=tf.nn.relu),
        # A (1, 1) pool with (1, 1) strides leaves the tensor unchanged.
        layers.MaxPooling2D((1, 1), padding='same'),
        layers.Flatten(),
        # No activation: compute_loss applies softmax cross-entropy to logits.
        layers.Dense(n_classes),
    ])

In [3]:
def compute_loss(logits, labels):
    """Return the batch-mean sparse softmax cross-entropy.

    Args:
        logits: unnormalised class scores, one row per example.
        labels: integer class index for each example.

    Returns:
        A scalar tensor holding the average loss over the batch.
    """
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example_loss)

In [4]:
def train_step(n_steps=1):
    """Build both checkpointed blocks and run `n_steps` SGD updates on a toy batch.

    Despite its name, each call constructs fresh models and a fresh optimizer
    and then runs a small training loop, printing the loss after every
    forward pass.

    Args:
        n_steps: number of optimisation steps to run.
    """
    img_dim, n_channels, batch_size = 2, 1, 1
    # Toy data: an all-ones image batch where every label is class 1.
    images = tf.ones([batch_size, img_dim, img_dim, n_channels])
    labels = tf.ones([batch_size], dtype=tf.int64)

    # Keep handles on the plain Keras models: their variables are what we
    # differentiate against and update.
    block1_model = simple_block1(img_dim, n_channels)
    block2_model = simple_block2()
    # Wrap every block with tf.recompute_grad so its activations are
    # recomputed during the backward pass instead of being stored.
    block1 = tf.recompute_grad(block1_model)
    block2 = tf.recompute_grad(block2_model)

    optimizer = optimizers.SGD()
    train_vars = block1_model.trainable_variables + block2_model.trainable_variables

    for _ in range(n_steps):
        with tf.GradientTape() as tape:
            # NOTE(review): passing `trainable_variables=` relies on how this
            # TensorFlow version's recompute_grad handles keyword arguments —
            # confirm when upgrading TensorFlow.
            hidden = block1(images, trainable_variables=block1_model.trainable_variables)
            logits = block2(hidden, trainable_variables=block2_model.trainable_variables)
            loss = compute_loss(logits, labels)
            print('loss', loss)
        gradients = tape.gradient(loss, train_vars)
        optimizer.apply_gradients(zip(gradients, train_vars))

In [5]:
# A few steps are enough to see the loss start to fall.
train_step(n_steps=3)

loss tf.Tensor(0.6931472, shape=(), dtype=float32)
loss tf.Tensor(0.6881597, shape=(), dtype=float32)
loss tf.Tensor(0.68322194, shape=(), dtype=float32)
