# Preparing a TensorFlow 2 model for serving with TF Serving

This notebook shows how to prepare a TensorFlow 2 model for efficient serving with **TF Serving** by exposing flexible serving signatures. 

The inputs and outputs of the model as used during model training may not be optimal for serving. For example, in a typical training pipeline, feature engineering is performed as a separate step preceding model training and hyperparameter tuning. When serving the model, it is often better to embed the feature engineering logic into the serving interface than to require each client application to preprocess the data.

This notebook uses the pretrained [ResNet V2 101](https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4) image classification model from [TF Hub](https://tfhub.dev/). The demonstrated patterns and practices can be easily generalized to other types of TensorFlow 2 models.



## Setting up the environment

In [None]:
import base64
import os
import json
import requests
import time
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt

from typing import List, Optional, Text, Tuple

#tf.debugging.set_log_device_placement(True) 

### Set model store paths

Update the `GCS_MODEL_STORE` constant with the GCS path where the model created in the following sections of this notebook will be stored.

In [None]:
GCS_MODEL_STORE = 'gs://mlops-dev-workspace/models'

### Set TF Hub URLs

In [None]:
THUB_MODEL_HANDLE = 'https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4'
IMAGENET_LABELS_URL = 'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt'

### Create a local workspace

In [None]:
LOCAL_WORKSPACE = '/tmp/workspace'
if tf.io.gfile.exists(LOCAL_WORKSPACE):
  print("Removing previous workspace artifacts...")
  tf.io.gfile.rmtree(LOCAL_WORKSPACE)

print("Creating a new workspace...")
tf.io.gfile.makedirs(LOCAL_WORKSPACE)

## Loading and running the ResNet v2 101 model

### Download and instantiate the model

In [None]:
os.environ["TFHUB_DOWNLOAD_PROGRESS"] = 'True'

local_savedmodel_path = hub.resolve(THUB_MODEL_HANDLE)

print(local_savedmodel_path)
!ls -la {local_savedmodel_path}

In [None]:
model = hub.load(THUB_MODEL_HANDLE)

The expected input to most TF Hub TF2 image classification models, including ResNet 101, is a rank 4 tensor conforming to the following tensor specification: `tf.TensorSpec([None, height, width, 3], tf.float32)`. For the ResNet 101 model, the expected image size is `height x width = 224 x 224`. The color values for all channels are expected to be normalized to the [0, 1] range. 

The output of the model is a batch of logits vectors. Each index into a logits vector corresponds to one of the `num_classes = 1001` ImageNet classification classes. The mapping from indices to class labels can be found in the [labels file](https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt), with class 0 for "background" followed by the 1000 actual ImageNet classes.
    

### Test the model

We will now test the model on a couple of JPEG images. 

#### Display sample images

In [None]:
image_folder = 'test_images'
image_list = [tf.io.read_file(os.path.join(image_folder, image_path))
         for image_path in os.listdir(image_folder)]

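ncolumns = min(len(image_list), 4)
# Ceiling division so that a partially filled last row is not dropped.
nrows = -(-len(image_list) // ncolumns)
# squeeze=False keeps `axes` a 2D array even when there is a single image.
fig, axes = plt.subplots(nrows=nrows, ncols=ncolumns, figsize=(10,10), squeeze=False)
for axis, image in zip(axes.flat, image_list):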
    decoded_image = tf.image.decode_image(image)
    axis.set_title(decoded_image.shape)
    axis.imshow(decoded_image.numpy())

#### Preprocess the testing images

The images need to be preprocessed to conform to the expected format.

In [None]:
def _decode_and_scale(image, size):
    image = tf.image.decode_image(image, expand_animations=False)
        
    image_height = image.shape[0]
    image_width = image.shape[1]
    crop_size = tf.minimum(image_height, image_width)
    offset_height = ((image_height - crop_size) + 1) // 2
    offset_width = ((image_width - crop_size) + 1) // 2
        
    image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, crop_size, crop_size)
    image = tf.cast(tf.image.resize(image, [size, size]), tf.uint8)
    
    return image
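
The offset arithmetic in `_decode_and_scale` can be checked by hand on plain integers. The sketch below is a standalone illustration, not part of the notebook's pipeline: `center_crop_box` is a hypothetical helper name and the image sizes are made-up examples.

```python
def center_crop_box(image_height, image_width):
    """Return (offset_height, offset_width, crop_size) for a centered square crop."""
    crop_size = min(image_height, image_width)
    # Same rounding as `_decode_and_scale`: when the leftover margin is odd,
    # the extra pixel goes before the crop window (top or left).
    offset_height = ((image_height - crop_size) + 1) // 2
    offset_width = ((image_width - crop_size) + 1) // 2
    return offset_height, offset_width, crop_size


# A hypothetical 480 x 640 (height x width) landscape image.
print(center_crop_box(480, 640))  # (0, 80, 480)

# A hypothetical 501 x 400 portrait image with an odd vertical margin.
print(center_crop_box(501, 400))  # (51, 0, 400)
```

The square crop is then resized to `size x size`, so the aspect ratio of the cropped region is preserved.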

In [None]:
size = 224

raw_images = tf.stack(image_list)
preprocessed_images = tf.map_fn(lambda x: _decode_and_scale(x, size), raw_images, dtype=tf.uint8)
preprocessed_images = tf.image.convert_image_dtype(preprocessed_images, tf.float32)
print(preprocessed_images.shape)

#### Run inference

In [None]:
predictions = model(preprocessed_images)
predictions

As we invoked the model on a batch with two images, the model returned a batch of two arrays with logits. This is not a very user-friendly output, so we will convert it to a list of ImageNet class labels.

#### Download ImageNet labels

In [None]:
labels_path = tf.keras.utils.get_file(
    'ImageNetLabels.txt',
    IMAGENET_LABELS_URL)
imagenet_labels = np.array(open(labels_path).read().splitlines())

#### Map the logits to class labels

We will display the 5 highest ranked labels for each image.

In [None]:
for prediction in list(predictions):
    decoded = imagenet_labels[np.argsort(prediction.numpy())[::-1][:5]]
    print(list(decoded))

## Preparing the model for TF Serving

The ResNet V2 101 model from TF Hub is optimized for recomposition and fine-tuning. Since there are no serving signatures in the model's metadata, it cannot be served with TF Serving as is.

In [None]:
list(model.signatures)

To make it servable, we need, at minimum, to add a serving signature describing the default inference method of the model. Since the expected inputs to our model require relatively complex image preprocessing by the invoking client, we will also embed the preprocessing and postprocessing logic directly into the model and expose an alternative signature that accepts raw unprocessed images and returns the list of ranked class labels and associated label probabilities.

This is achieved by defining a custom module class, derived from the `tf.Module` base class, that encapsulates our ResNet model and extends it with a method implementing the image preprocessing and output postprocessing logic. The default method of the custom module is mapped to the default method of the base ResNet module so that the original interface is preserved.

The custom module will be exported as the `SavedModel` that includes the original model, the preprocessing logic, and two serving signatures.

This technique can be generalized to other scenarios where you need to extend a TensorFlow model and you have access to the serialized `SavedModel` but you don't have access to the Python code implementing the model.


#### Define the custom serving module

In [None]:
LABELS_KEY = 'labels'
PROBABILITIES_KEY = 'probabilities'
NUM_LABELS = 5

class ServingModule(tf.Module):
    """
    A custom tf.Module that adds image preprocessing and output post processing to
    a base TF 2 image classification model from TF Hub. 
    """

    def __init__(self, base_model, input_size, output_labels):
        super(ServingModule, self).__init__()
        self._model = base_model
        self._input_size = input_size
        self._output_labels = tf.constant(output_labels, dtype=tf.string)
        

    def _decode_and_scale(self, raw_image):
        """
        Decodes, crops, and resizes a single raw image.
        """
        
        image = tf.image.decode_image(raw_image, dtype=tf.dtypes.uint8, expand_animations=False)
        image_shape = tf.shape(image)
        image_height = image_shape[0]
        image_width = image_shape[1]
        crop_size = tf.minimum(image_height, image_width)
        offset_height = ((image_height - crop_size) + 1) // 2
        offset_width = ((image_width - crop_size) + 1) // 2
        
        image = tf.image.crop_to_bounding_box(image, offset_height, offset_width, crop_size, crop_size)
        image = tf.image.resize(image, [self._input_size, self._input_size])
        image = tf.cast(image, tf.uint8)
    
        return image
    
    def _preprocess(self, raw_inputs):
        """
        Preprocesses raw inputs as sent by the client.
        """
        
        # A mitigation for https://github.com/tensorflow/tensorflow/issues/28007
        with tf.device('/cpu:0'):
            images = tf.map_fn(self._decode_and_scale, raw_inputs, dtype=tf.uint8)
        images = tf.image.convert_image_dtype(images, tf.float32)
        
        return images
        
    def _postprocess(self, model_outputs):
        """
        Postprocesses outputs returned by the base model.
        """
        
        probabilities = tf.nn.softmax(model_outputs)
        indices = tf.argsort(probabilities, axis=1, direction='DESCENDING')
        
        return {
            LABELS_KEY: tf.gather(self._output_labels, indices, axis=-1)[:,:NUM_LABELS],
            PROBABILITIES_KEY: tf.sort(probabilities, direction='DESCENDING')[:,:NUM_LABELS]
        }
        

    @tf.function(input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.float32)])
    def __call__(self, x):
        """
        A pass-through to the base model.
        """
        
        return self._model(x)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
    def predict_labels(self, raw_images):
        """
        Preprocesses inputs, calls the base model 
        and postprocesses outputs from the base model.
        """
        
        # Call the preprocessing handler
        images = self._preprocess(raw_images)
        
        # Call the base model
        logits = self._model(images)
        
        # Call the postprocessing handler
        outputs = self._postprocess(logits)
        
        return outputs
        
    
serving_module = ServingModule(model, 224, imagenet_labels)

#### Test the custom serving module

In [None]:
predictions = serving_module.predict_labels(raw_images)
predictions

#### Save the custom serving module as `SavedModel`

In [None]:
model_name = 'resnet_serving'
model_version = '1'
model_path = os.path.join(LOCAL_WORKSPACE, model_name, model_version)

default_signature = serving_module.__call__.get_concrete_function()
preprocess_signature = serving_module.predict_labels.get_concrete_function()
signatures = {
    'serving_default': default_signature,
    'serving_preprocess': preprocess_signature
}

tf.saved_model.save(serving_module, model_path, signatures=signatures)
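
The `model_version` path component matters because TF Serving looks for numeric version subdirectories under a model's base path and, by default, serves the highest one. The sketch below mimics that lookup using only the standard library. It is an illustration under stated assumptions: `servable_versions` is a hypothetical helper, and the directory layout is created in a temporary folder rather than in `LOCAL_WORKSPACE`.

```python
import os
import tempfile


def servable_versions(model_base_path):
    """Return the numeric version subdirectories under a base path, ascending."""
    versions = []
    for name in os.listdir(model_base_path):
        full_path = os.path.join(model_base_path, name)
        # Only directories with purely numeric names count as versions.
        if os.path.isdir(full_path) and name.isdigit():
            versions.append(int(name))
    return sorted(versions)


# Build a throwaway layout that mimics <workspace>/resnet_serving/<version>.
with tempfile.TemporaryDirectory() as workspace:
    base_path = os.path.join(workspace, 'resnet_serving')
    for subdir in ('1', '2', 'notes'):
        os.makedirs(os.path.join(base_path, subdir))
    versions = servable_versions(base_path)

print(versions)      # [1, 2]
print(versions[-1])  # 2, the version served by default
```

This is also why the base path, not the version folder, is what gets mounted into the TF Serving container.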

#### Inspect the `SavedModel`

In [None]:
!saved_model_cli show --dir {model_path} --tag_set serve --all

## Test the custom serving module with TF Serving

We will now run the exported custom serving module using the TF Serving docker image.

### Start the TF Serving server

To start the TF Serving server and expose the TF Serving REST API port (8501):

1. Open a JupyterLab terminal.
2. Pull the latest TF Serving docker image.

```
docker pull tensorflow/serving
```


3. Set the environment variables for the model's name and the model's path. Use the value printed by the next cell for the model's path.

```
export MODEL_PATH=[YOUR_MODEL_PATH]
export MODEL_NAME=resnet
```

In [None]:
# Print the parent of the version folder, which is the path TF Serving needs.
print(os.path.dirname(model_path))

4. Start the server and register your custom module for serving.

```
docker run -it --rm -p 8501:8501 \
-v $MODEL_PATH:/models/$MODEL_NAME \
-e MODEL_NAME=$MODEL_NAME \
tensorflow/serving 
```
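
Before sending prediction requests, it can help to confirm that the model has finished loading. The sketch below polls the model status endpoint described in the TF Serving REST API reference (`GET /v1/models/<name>`). The helper names, the retry count, and the delay are assumptions made for illustration. The network call is left commented out so the cell is safe to run without a server.

```python
import json
import time
import urllib.request


def model_status_url(model_name, host='localhost', port=8501):
    """Build the model status URL for a TF Serving REST endpoint."""
    return f'http://{host}:{port}/v1/models/{model_name}'


def is_available(status):
    """Return True if any version in a status payload reports AVAILABLE."""
    states = [v.get('state') for v in status.get('model_version_status', [])]
    return 'AVAILABLE' in states


def fetch_status(url):
    """Fetch and decode the JSON status payload."""
    with urllib.request.urlopen(url, timeout=5) as response:
        return json.loads(response.read().decode('utf-8'))


def wait_until_available(url, fetch=fetch_status, attempts=30, delay_seconds=1.0):
    """Poll the status endpoint until the model is available or attempts run out."""
    for _ in range(attempts):
        try:
            if is_available(fetch(url)):
                return True
        except OSError:
            pass  # The server is not accepting connections yet.
        time.sleep(delay_seconds)
    return False


print(model_status_url('resnet'))
# With the container running:
# wait_until_available(model_status_url('resnet'))
```

The `fetch` parameter exists only so the polling logic can be exercised without a live server.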



### Run inference

We will now run inference by invoking the TF Serving `Predict` API.

Refer to the [TF Serving REST API Reference](https://www.tensorflow.org/tfx/serving/api_rest) for more information about the API format.

#### Invoke the model using the `serving_preprocess` signature

Load sample images.

In [None]:
image_folder = 'test_images'
image_list = [tf.io.read_file(os.path.join(image_folder, image_path))
         for image_path in os.listdir(image_folder)]

Encode the images using base 64 encoding.

In [None]:
raw_images_as_list = [{'b64': base64.b64encode(image.numpy()).decode('utf-8')} for image in image_list]

Prepare the request body and headers.

In [None]:
request_body = {
    'signature_name': 'serving_preprocess',
    'instances': raw_images_as_list
}

headers = {"content-type": "application/json"}

Invoke the `Predict` endpoint and display results.

In [None]:
uri = 'http://localhost:8501/v1/models/resnet:predict'

response = requests.post(uri, data=json.dumps(request_body), headers=headers)

response.json()
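
Because the `serving_preprocess` signature accepts encoded image bytes, a client does not need TensorFlow to call it. The sketch below builds the same request body with the standard library only and reads a response in the row format. The helper names are hypothetical, the image bytes are placeholders, and the sample response is hand-written; it assumes the output keys `labels` and `probabilities` defined by `LABELS_KEY` and `PROBABILITIES_KEY`.

```python
import base64
import json


def build_predict_request(raw_images, signature_name='serving_preprocess'):
    """Serialize raw image bytes into a TF Serving row-format request body."""
    instances = [{'b64': base64.b64encode(image).decode('utf-8')}
                 for image in raw_images]
    return json.dumps({'signature_name': signature_name, 'instances': instances})


def top_label(prediction):
    """Return the (label, probability) pair ranked first in one prediction."""
    return prediction['labels'][0], prediction['probabilities'][0]


# Placeholder bytes stand in for the content of a real JPEG file.
body = build_predict_request([b'\xff\xd8fake-jpeg-bytes'])

# A hand-written response in the row format; labels and values are made up.
sample_response = {
    'predictions': [
        {'labels': ['label_a', 'label_b'], 'probabilities': [0.7, 0.2]}
    ]
}
for prediction in sample_response['predictions']:
    print(top_label(prediction))
```

With a running server, `body` would be posted to the same `:predict` URI used above, and `response.json()` would take the place of `sample_response`.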

#### Invoke the model using the `serving_default` signature

Preprocess the raw images.

In [None]:
size = 224

raw_images = tf.stack(image_list)
preprocessed_images = tf.map_fn(lambda x: _decode_and_scale(x, size), raw_images, dtype=tf.uint8)
preprocessed_images = tf.image.convert_image_dtype(preprocessed_images, tf.float32)
print(preprocessed_images.shape)

Prepare the request body and headers.

In [None]:
images_as_list = preprocessed_images.numpy().tolist()

request_body = {
    'signature_name': 'serving_default',
    'instances': images_as_list
}

headers = {"content-type": "application/json"}

Invoke the `Predict` endpoint and display results.

In [None]:
uri = 'http://localhost:8501/v1/models/resnet:predict'

response = requests.post(uri, data=json.dumps(request_body), headers=headers)

print(str(response.json())[0:200], '...')
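
The `serving_default` signature returns raw logits, so the client has to apply the softmax and ranking that `_postprocess` performs inside the model. The sketch below shows that step in plain Python on toy data. The function names, labels, and logits are illustrative assumptions; with a real response, each entry of `response.json()['predictions']` would supply the logits and `imagenet_labels` the labels.

```python
import math


def softmax(logits):
    """Convert a list of logits to probabilities in a numerically stable way."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits, labels, k=5):
    """Return the k most probable (label, probability) pairs, highest first."""
    probabilities = softmax(logits)
    ranked = sorted(range(len(probabilities)),
                    key=lambda i: probabilities[i], reverse=True)
    return [(labels[i], probabilities[i]) for i in ranked[:k]]


# Toy data standing in for one 1001-element logits vector and its labels.
toy_labels = ['background', 'class_a', 'class_b', 'class_c']
toy_logits = [0.0, 2.0, 1.0, -1.0]

for label, probability in top_k(toy_logits, toy_labels, k=2):
    print(label, round(probability, 3))
```

Having to repeat this logic in every client is one reason to prefer the `serving_preprocess` signature.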

### Stop the TF Serving container

1. Open a new JupyterLab terminal.
2. Get the ID of your TF Serving container.

```
docker ps
```
3. Terminate your TF Serving container.

```
docker kill [YOUR_CONTAINER_ID]
```


## Copy the custom serving module to GCS

In [None]:
local_path = os.path.join(LOCAL_WORKSPACE, model_name)

!gsutil cp -r {local_path} {GCS_MODEL_STORE}

In [None]:
!gsutil ls {GCS_MODEL_STORE}

## Next Steps

Walk through the `aipp_deploy.ipynb` notebook to learn how to deploy the custom serving module created in this notebook to **AI Platform Prediction** using the TF Serving container image.

## License

<font size=-1>Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at [https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the License for the specific language governing permissions and limitations under the License.</font>