This notebook shows how to export a custom training loop to the SavedModel format so that the model can later be trained in a different environment.

In [None]:
!pip install tensorflow==2.3.0
import os

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import tensorflow as tf
import tensorflow.keras as keras

# Training Step in tf.function

In [None]:
# Fix the NumPy and TensorFlow global seeds so the data split, shuffling and
# weight initialization are reproducible from run to run.
np.random.seed(2)
tf.random.set_seed(5)

def make_model():
    """Build the small regression model used throughout this notebook.

    Functional-API Keras model: 8 input features -> Dense(30, relu) named
    'my_dense' -> CustomLayer, a demo custom layer wrapping a single-unit
    Dense with L2(0.1) kernel regularization. A model of any complexity
    could be used instead.
    """
    from tensorflow.keras import layers

    class CustomLayer(layers.Layer):
        """Demo custom layer: Dense(1) named 'my_layer_dense' with an L2 penalty."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.dense = layers.Dense(
                1,
                kernel_regularizer=keras.regularizers.l2(0.1),
                name='my_layer_dense',
            )

        def call(self, data):
            return self.dense(data)

    inputs = keras.Input(shape=(8,))
    hidden = layers.Dense(30, activation="relu", name='my_dense')(inputs)
    return keras.Model(inputs=inputs, outputs=CustomLayer()(hidden))

# Prepare the training dataset.
def get_housing_dataset():
    """Load the California housing data, split it and standardize the features.

    The data is split twice with scikit-learn's default split sizes:
    full -> (train_full, test), then train_full -> (train, valid). No
    random_state is passed, so the splits follow NumPy's global RNG.
    The StandardScaler is fitted on the training features only and then
    applied to the validation and test features. Every returned array is
    cast to float32.

    Returns:
        (X_train, X_valid, X_test, y_train, y_valid, y_test)
    """
    from sklearn.datasets import fetch_california_housing
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import StandardScaler

    housing = fetch_california_housing()
    X_rest, X_test, y_rest, y_test = train_test_split(housing.data, housing.target)
    X_train, X_valid, y_train, y_valid = train_test_split(X_rest, y_rest)

    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_valid = scaler.transform(X_valid)
    X_test = scaler.transform(X_test)

    def as_f32(array):
        return array.astype(np.float32)

    return (as_f32(X_train), as_f32(X_valid), as_f32(X_test),
            as_f32(y_train), as_f32(y_valid), as_f32(y_test))

# The test split is not used in this notebook and is discarded.
X_train, X_valid, _, y_train, y_valid, _ = get_housing_dataset()

batch_size = 64
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)
valid_dataset = (
    tf.data.Dataset.from_tensor_slices((X_valid, y_valid))
    .batch(batch_size)
)


class CustomModule(tf.Module):
    """tf.Module bundling a Keras model and its optimizer for SavedModel export.

    Both the model and the optimizer are tracked attributes, so
    tf.saved_model.save writes the layer weights *and* the Adam slot
    variables into the checkpoint. The attribute names `model` and `opt`
    appear in the checkpoint keys (e.g.
    'model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/...'), so they
    must not be renamed.
    """

    def __init__(self):
        super().__init__()
        self.model = make_model()
        self.opt = keras.optimizers.Adam(learning_rate=0.001)

    # @tf.function runs the method as a graph (faster). Its input_signature
    # pins the accepted shapes and dtypes, so one concrete function can be
    # exported. To debug:
    # - use tf.print(), which also runs in graph mode, or
    # - run eagerly, either by removing the decorator or by calling
    #   tf.config.run_functions_eagerly(True). In eager mode print(), any
    #   other Python statement and debugger breakpoints work as usual.

    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32)])
    def __call__(self, X):
        """Return the model's predictions, shape (batch, 1), for X of shape (batch, 8)."""
        return self.model(X)

    @tf.function(input_signature=[tf.TensorSpec([None, 8], tf.float32),
                                  tf.TensorSpec([None], tf.float32)])
    def my_train(self, X, y):
        """Run one optimization step on a single batch and return its loss.

        Args:
            X: float32 features, shape (batch, 8).
            y: float32 targets, shape (batch,).

        Returns:
            Scalar loss: the batch mean squared error plus the model's
            regularization losses.
        """
        with tf.GradientTape() as tape:
            preds = self.model(X, training=True)  # shape (batch, 1)
            # Reshape y to (batch, 1) so it matches the predictions. The
            # function form keras.losses.mean_squared_error does not squeeze,
            # so (batch, 1) vs (batch,) would broadcast to (batch, batch).
            # That would average every target against every prediction
            # instead of computing the per-sample error.
            y_col = tf.reshape(y, [-1, 1])
            main_loss = tf.reduce_mean(keras.losses.mean_squared_error(y_col, preds))
            # self.model.losses contains the regularization loss (see the
            # kernel_regularizer in make_model).
            loss_value = tf.add_n([main_loss] + self.model.losses)

        grads = tape.gradient(loss_value, self.model.trainable_weights)
        self.opt.apply_gradients(zip(grads, self.model.trainable_weights))
        return loss_value

# Keep tf.functions compiled as graphs. Pass True instead to force eager
# execution despite the @tf.function decorators (useful for debugging).
tf.config.run_functions_eagerly(run_eagerly=False)

# Build the module, then call it on one sample (this goes through __call__).
module = CustomModule()
print('sample prediction: ', module(X_train[:1]).numpy())

Let's call the `my_train` function repeatedly to train the model

In [None]:
def train_module(module, train_dataset, valid_dataset, epochs=3, eval_every=100):
    """Train `module` with its `my_train` step and periodically report validation MSE.

    Only `module.my_train(X, y)` and `module(X)` are used. It therefore works
    both with a freshly built CustomModule and with a module reloaded from
    a SavedModel.

    Args:
        module: object exposing `my_train(X, y) -> loss tensor` and `__call__(X)`.
        train_dataset: iterable of (X, y) training batches.
        valid_dataset: iterable of (X, y) validation batches.
        epochs: number of passes over `train_dataset` (default 3).
        eval_every: evaluate on `valid_dataset` every this many training
            steps (default 100). The step counter runs across epochs.

    Returns:
        List with the loss value of every training step.
    """
    valid_metric = keras.metrics.MeanSquaredError()
    loss_hist = []
    step = 1
    for _ in range(epochs):
        for X, y in train_dataset:
            loss = module.my_train(X, y)
            loss_hist.append(loss.numpy())

            if step % eval_every == 0:
                # Reset so the printed value is the MSE of this evaluation
                # only, not a running average over every previous evaluation.
                valid_metric.reset_states()
                for X_val, y_val in valid_dataset:
                    valid_metric.update_state(y_val, module(X_val))
                print(f'Mean squared error: step {step}: {valid_metric.result()}')
            step += 1
    return loss_hist

def plot_loss(loss_hist):
    """Plot the per-step training loss on a new 8x4 figure with a grid."""
    fig_size = (8, 4)
    plt.figure(figsize=fig_size)
    plt.plot(loss_hist)
    plt.title('loss', fontsize=15)
    plt.grid()
    
# Train the freshly created module, then plot its per-step loss.
loss_hist = train_module(module=module, train_dataset=train_dataset,
                         valid_dataset=valid_dataset)
plot_loss(loss_hist=loss_hist)

Let's check the state of the `ADAM` optimizer that we use. ADAM learns two variables, `m` and `v`, associated with each weight.

`m` and `v` are estimates of the first moment (the mean) and the second moment (the uncentered variance) of the gradients respectively  
For more info: https://ruder.io/optimizing-gradient-descent/index.html#adam

In Tensorflow terms they are called `slots`. For more info about how they are tracked see https://www.tensorflow.org/guide/checkpoint#loading_mechanics

Let's examine the content of ADAM's m slot for the bias of the first dense layer

In [None]:
# Display opt.weights[2], the Adam `m` slot for the bias of the first Dense
# layer. It is left as the cell's last expression so the notebook shows it.
module.opt.weights[2]

Let's train a bit further. As expected the loss remains low (note the scale of this plot)

In [None]:
# Continue training the same in-memory module and plot the new losses.
loss_hist = train_module(module=module, train_dataset=train_dataset,
                         valid_dataset=valid_dataset)
plot_loss(loss_hist=loss_hist)

We can then check that the variables of the ADAM optimizer changed as well:

In [None]:
# Display the same Adam slot again to compare it with its value before the
# extra training. It is left as the last expression so the notebook shows it.
module.opt.weights[2]

# Persist Model and continue Training

All the above was done in memory. When we save a model, it is useful to save not only the weights of the layers but also the optimizer state. Then, when we continue training a reloaded model, the optimizer doesn't need to re-learn its variables (the `m` and `v` moments in the case of ADAM).

Let's save the module in the SavedModel format. The SavedModel format contains signatures that describe the exported functions with their inputs and outputs available when we load the model. For more info see https://www.tensorflow.org/guide/saved_model

Then we'll inspect the content of the checkpoint saved with the model

In [None]:
def save_module(module, model_dir):
    """Export `module` as a SavedModel in `model_dir` with two explicit signatures.

    tf.keras.Model.save(), tf.keras.models.save_model() and
    tf.saved_model.save() on a Keras Model add a `serving_default`
    signature automatically. A plain tf.Module does not get one, so the
    signatures are listed here by hand:
      - 'my_serve': the `__call__` prediction function, X float32 [None, 8]
      - 'my_train': the `my_train` training step, X float32 [None, 8] and
        y float32 [None]
    Exporting the training step is what allows training after reloading.
    """
    features_spec = tf.TensorSpec([None, 8], tf.float32)
    targets_spec = tf.TensorSpec([None], tf.float32)

    serve_fn = module.__call__.get_concrete_function(features_spec)
    train_fn = module.my_train.get_concrete_function(features_spec, targets_spec)

    tf.saved_model.save(module, model_dir,
                        signatures={'my_serve': serve_fn, 'my_train': train_fn})
    
def inspect_checkpoint(checkpoint, print_values=False, variables=None):
    """Print the name and shape/type (optionally the value) of checkpoint entries.

    Args:
        checkpoint: checkpoint prefix, e.g. model_dir + '/variables/variables'.
        print_values: also print each entry's value.
        variables: names of the entries to show. When falsy (None or an
            empty list), every variable listed in the checkpoint is shown.

    Entries whose read raises are reported as 'ignored' and skipped.
    """
    names = variables or [name for name, _ in tf.train.list_variables(checkpoint)]

    reader = tf.train.load_checkpoint(checkpoint)
    for name in names:
        try:
            value = reader.get_tensor(name)
        except Exception as e:
            print('ignored   : %s (exception %s)' % (name, str(type(e))))
            continue

        if isinstance(value, np.ndarray):
            label, details = 'tensor    : ', (name, value.shape)
        else:
            label, details = 'non-tensor: ', (name, type(value))

        if print_values:
            print(label, *details, value)
        else:
            print(label, *details)

model_dir = 'saved_model'
os.makedirs(model_dir, exist_ok=True)

# Export the trained module, then list everything stored in its checkpoint.
save_module(module=module, model_dir=model_dir)
inspect_checkpoint(checkpoint=model_dir + '/variables/variables')

We can see not only the layers' weights but also ADAM's slot variables (`m` and `v`) discussed earlier.  

Excerpt from https://www.tensorflow.org/guide/saved_model#saving_a_custom_model

> When you save a `tf.Module`, any `tf.Variable` attributes, `tf.function`-decorated methods, and `tf.Modules` found via recursive traversal are saved. (See the [Checkpoint tutorial](https://www.tensorflow.org/guide/checkpoint) for more about this recursive traversal.)


This is covered at great length in https://www.tensorflow.org/guide/checkpoint#loading_mechanics  

Let's examine the content of ADAM's m slot for the bias of the first dense layer

In [None]:
# Print Adam's `m` slot for the first Dense layer's bias as saved on disk.
inspect_checkpoint(
    checkpoint=model_dir + '/variables/variables',
    variables=['model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE'],
    print_values=True,
)

We can see that this is exactly the in-memory content of that slot checked above with
```python
module.opt.weights[2]
``` 
This shows that ADAM's state was indeed saved in the checkpoint. Let's have a look at the exported signatures with the `saved_model_cli` tool bundled with TensorFlow

In [None]:
# Shell out to saved_model_cli to list every signature of the export (input and
# output keys, dtypes, shapes and graph tensor names).
!saved_model_cli show --all --dir $model_dir

We see the expected signatures for the prediction (`my_serve`) and training (`my_train`) exported functions. More on this later.  
Let's create a fresh instance of the module, save it untrained and reload it 

In [None]:
# Build a fresh, untrained module and export it over the previous SavedModel.
module = CustomModule()
save_module(module=module, model_dir=model_dir)
del module

print('\n\n========== Reload module ===========')

# Nothing below relies on the in-memory module, so this works just as well
# when reloading in another Python process.
# NOTE(review): the export is a plain tf.Module rather than a keras Model.
# tf.saved_model.load is the generic loader for such SavedModels.
# tf.keras.models.load_model is kept because this notebook was tested with it
# on TF 2.3; confirm before switching loaders.
model_dir = 'saved_model'
new_module = tf.keras.models.load_model(filepath=model_dir)

What's noteworthy here is that the loaded `new_module` object is not a `tf.Module` instance but another kind of object that still offers the `my_train` and `__call__` functions we exported. Let's call the `__call__()` method to see that it works (the prediction is meaningless since the model is not trained yet)

In [None]:
# Compare the reloaded object's type with that of a freshly built
# CustomModule (constructed here only to print its type). Then show that
# the exported functions are available on the reloaded object.
print('type of reloaded module:', type(new_module))
print('type of instantiated module:', type(CustomModule()))
print('my_train function:', new_module.my_train)
print('__call__ function:', new_module.__call__)

# demo a call to the module. (calls the __call__() method)
print('sample prediction: ', new_module(X_train[0:1]).numpy())

As a consequence we can still use the `my_train` function with our `train_module` function. Let's train the reloaded module and save it afterwards

In [None]:
# Re-seed the global RNGs, then train the reloaded module exactly like the
# original one and export the result.
np.random.seed(3)
tf.random.set_seed(5)
loss_hist = train_module(module=new_module, train_dataset=train_dataset,
                         valid_dataset=valid_dataset)
plot_loss(loss_hist=loss_hist)

save_module(module=new_module, model_dir=model_dir)

The above shows that we can load a module and train it exactly as if we had instantiated it with `CustomModule()`. Let's check some of the optimizer state like we did above

In [None]:
# Print Adam's `m` slot for the first Dense layer's bias after training the
# reloaded module.
inspect_checkpoint(
    checkpoint=model_dir + '/variables/variables',
    variables=['model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE'],
    print_values=True,
)

Reload the module, continue the training and save it

In [None]:
# Drop the previous reloaded object and reload the SavedModel from disk.
del new_module
new_module_2 = tf.keras.models.load_model(filepath=model_dir)

# Resume training from the reloaded state, then export again.
loss_hist = train_module(module=new_module_2, train_dataset=train_dataset,
                         valid_dataset=valid_dataset)
plot_loss(loss_hist=loss_hist)
save_module(module=new_module_2, model_dir=model_dir)

As with the in-memory training earlier, the above shows that the weights have been correctly reloaded and that we didn't restart training from scratch. What about ADAM's variables?

In [None]:
# Print Adam's `m` slot for the first Dense layer's bias after the second
# reload-and-train round.
inspect_checkpoint(
    checkpoint=model_dir + '/variables/variables',
    variables=['model/layer_with_weights-0/bias/.OPTIMIZER_SLOT/opt/m/.ATTRIBUTES/VARIABLE_VALUE'],
    print_values=True,
)

ADAM's variables have changed as well. They change less and less as training continues. This can be shown by taking the norm of the difference of those slots before and after training and seeing that this norm decreases over time. 

This shows that the optimizer state is also captured in the saved model and that we can stop and resume training without losing anything. 

# Low Level Operations
All the above was done with Python objects and methods available when reloading a module. But how can we do it in another language where only the graph and operations are available?

Let's first see how to do it in Python. Here's the `saved_model_cli` output for the `my_train` signature again:
```
signature_def['my_train']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['X'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 8)
        name: my_train_X:0
    inputs['y'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1)
        name: my_train_y:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: ()
        name: StatefulPartitionedCall_1:0
  Method name is: tensorflow/serving/predict
 ```
 
It turns out we can access the input and output tensors by the names shown here. For example:
 - `inputs['X']` has name `my_train_X:0`
 - `outputs['output_0']` (the loss) has name `StatefulPartitionedCall_1:0`  

What the signatures do not show is the tensor and the operation used to save the model:
- name of the checkpoint: `saver_filename:0` : must point to `model_dir + '/variables/variables'`
- save operation: `StatefulPartitionedCall_2:0`: the next `StatefulPartitionedCall` after the ones exported by our module

The information about the save operation is of course not documented, probably intentionally, so this might not work in future tensorflow versions.

In [None]:
def train_predict_serve(model_dir):
    """Train one step, predict and save a SavedModel using only TF1 graph ops.

    This mimics what a client without the Python object API (e.g. the Java
    API) can do: load the graph into a session, feed and fetch tensors by
    name, and run the save op. Tensor names are read from the MetaGraphDef
    returned by the loader (the 'my_train'/'my_serve' signature defs and the
    saver_def) rather than hard-coded. The names therefore stay correct if
    TensorFlow numbers its generated StatefulPartitionedCall ops differently.

    Uses the notebook globals X_train, y_train and batch_size as input data.
    Each call runs one training step, so predictions change between calls.
    """
    tf.compat.v1.reset_default_graph()
    # The context manager closes the session even if a lookup or run fails.
    with tf.compat.v1.Session() as session:
        meta_graph_def = tf.compat.v1.saved_model.loader.load(
            session, tags=[tf.saved_model.SERVING], export_dir=model_dir)
        graph = session.graph
        train_sig = meta_graph_def.signature_def['my_train']
        serve_sig = meta_graph_def.signature_def['my_serve']
        saver_def = meta_graph_def.saver_def

        def by_name(tensor_name):
            return graph.get_tensor_by_name(tensor_name)

        # Fetching the loss output of the 'my_train' signature runs one
        # training step on the fed batch.
        input_X = by_name(train_sig.inputs['X'].name)
        input_y = by_name(train_sig.inputs['y'].name)
        output_loss = by_name(train_sig.outputs['output_0'].name)
        loss = session.run(output_loss, feed_dict={input_X: X_train[0:batch_size],
                                                   input_y: y_train[0:batch_size]})
        print('loss:', loss)

        input_X_serve = by_name(serve_sig.inputs['X'].name)
        output_pred = by_name(serve_sig.outputs['output_0'].name)
        pred = session.run(output_pred, feed_dict={input_X_serve: X_train[0:1]})
        print('prediction:', pred)

        # saver_def gives the checkpoint-path placeholder and the save tensor.
        # In this notebook's TF 2.3 exports these were 'saver_filename:0' and
        # 'StatefulPartitionedCall_2:0'.
        saver_filename = by_name(saver_def.filename_tensor_name)
        save_op = by_name(saver_def.save_tensor_name)
        session.run(save_op, feed_dict={saver_filename: model_dir + '/variables/variables'})
        print('checkpoint saved')

train_predict_serve(model_dir)

If you execute `train_predict_serve()` more than once, you'll get different results since the model is training and predictions change.

The above shows that we can train, save a module and make predictions with only low level operations.  
This allows us to **export blank models** with their training and serving graphs and let a third-party organization train them and make predictions. The exported operations are enough for that organization to train the model, monitor the loss as it decreases, report accuracy on validation data sets and run inference.

*Side note*: if the module function returned two outputs, `saved_model_cli` would report them this way:
```
  The given SavedModel SignatureDef contains the following output(s):
    outputs['output_0'] tensor_info:
        dtype: DT_FLOAT
        shape: ()
        name: StatefulPartitionedCall_1:0
    outputs['output_1'] tensor_info:
        dtype: DT_FLOAT
        shape: ()
        name: StatefulPartitionedCall_1:1
```
And they could be fetched this way:  
```python
    input_X = graph.get_tensor_by_name('my_train_X:0')
    input_y = graph.get_tensor_by_name('my_train_y:0')
    output_1 = graph.get_tensor_by_name('StatefulPartitionedCall_1:0')
    output_2 = graph.get_tensor_by_name('StatefulPartitionedCall_1:1')
    
    out_val_1, out_val_2 = session.run([output_1, output_2], 
                                       feed_dict={input_X: X_train[0:1], input_y: y_train[0:1]})
```

The same exported model can be used for training and prediction in java with this code:

```java
public class TrainAndServeSavedModel {
    public static void main(String[] args) throws Exception {

        // args[0]: saved model directory
        SavedModelBundle savedModel = SavedModelBundle.load(args[0], "serve");
        Map<String, SignatureDef> signatureMap = savedModel.metaGraphDef().getSignatureDefMap();

        Tensor<TFloat32> inputTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[][] { { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f } }));
        Tensor<TFloat32> labelTensor = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[] { 1.0f }));
        
        Session session = savedModel.session();
        train(session, signatureMap.get("my_train"), inputTensor, labelTensor);
        serve(session, signatureMap.get("my_serve"),  inputTensor);
        session.close();
    }

    private static void serve(Session session, SignatureDef modelInfo, Tensor<TFloat32> inputTensor) {
        Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
        TensorInfo inputX = inputs.get("x");
        TensorInfo outputPred = modelInfo.getOutputsMap().get("output_0");

        Session.Runner runner = session.runner();
        runner.feed(inputX.getName(), inputTensor);
        TFloat32 data = runner.fetch(outputPred.getName()).run().get(0).expect(TFloat32.DTYPE).data();
        data.scalars().forEachIndexed((i, s) -> {
            System.out.println("prediction: " + s.getFloat());
        });
    }

    private static void train(Session session, SignatureDef modelInfo, Tensor<TFloat32> inputTensor, Tensor<TFloat32> labelTensor) {
        Map<String, TensorInfo> inputs = modelInfo.getInputsMap();
        TensorInfo inputX = inputs.get("X");
        TensorInfo inputY = inputs.get("y");
        TensorInfo outputLoss = modelInfo.getOutputsMap().get("output_0");

        Session.Runner runner = session.runner();
        runner.feed(inputX.getName(), inputTensor).feed(inputY.getName(), labelTensor);
        Tensor<TFloat32> loss = runner.fetch(outputLoss.getName()).run().get(0).expect(TFloat32.DTYPE);
        System.out.println("loss after training: " + loss.data().getFloat());
    }
}
```

Prints this:

```
loss after training: 1.2554951
prediction: 2.1101139
```

Tested with tensorflow 2.3.0. For more info:  
https://github.com/tensorflow/java  
https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/package-summary  


# Conclusion

We've shown how to export a training step into the SavedModel format and how to invoke it on a reloaded model in python as well as with low level operations in python or another language like java.

Another useful link about SavedModel manipulation (thanks Drew Hodun)  
https://towardsdatascience.com/how-to-extend-a-keras-model-5effc083265c