# Inference with saved TRT plan based on INT8 model

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants

In [4]:
# create an iterator for image batches in order calibrate and build the converter
ds = tfds.load('cats_vs_dogs', split='train', batch_size=64)
ds_iter = iter(ds)

In [6]:
model = tf.keras.models.load_model('saved_model/mobilenetv2')
print('Creating TensorRT Plan...')
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode=trt.TrtPrecisionMode.INT8,
                                                               max_workspace_size_bytes=8000000000,
                                                               use_calibration=True)

converter = trt.TrtGraphConverterV2(input_saved_model_dir='saved_model/mobilenetv2',
                                    conversion_params=conversion_params)


def input_fn():
    for _ in range(16):
        batch = next(ds_iter)['image']
        batch = tf.cast(batch, tf.float32)
        batch = (batch/127.5)-1
        batch = tf.image.resize(batch, (160,160))
        yield (batch,)

converter.convert(calibration_input_fn=input_fn)
converter.build(input_fn=input_fn)
converter.save(output_saved_model_dir='saved_model/mobilenetv2_TFTRT_INT8_built')
optimized_model = tf.saved_model.load('saved_model/mobilenetv2_TFTRT_INT8_built', tags=[tag_constants.SERVING])
graph_func = optimized_model.signatures[trt.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
trt_graph_def = graph_func.graph.as_graph_def()
for n in trt_graph_def.node:
    if n.op == "TRTEngineOp":
        print("Node: %s, %s" % (n.op, n.name.replace("/", "_")))
        with tf.io.gfile.GFile("%s.plan" % (n.name.replace("/", "_")), 'wb') as f:
            f.write(n.attr["serialized_segment"].s)
    else:
        print("Exclude Node: %s, %s" % (n.op, n.name.replace("/", "_")))

print('Done Creating TensorRT Plan')

Creating TensorRT Plan...
INFO:tensorflow:Linked TensorRT version: (7, 0, 0)
INFO:tensorflow:Loaded TensorRT version: (7, 0, 0)
INFO:tensorflow:Assets written to: saved_model/mobilenetv2_TFTRT_INT8_built/assets
Exclude Node: Placeholder, input
Exclude Node: PartitionedCall, PartitionedCall
Exclude Node: Identity, Identity
Done Creating TensorRT Plan


INFO:tensorflow:Linked TensorRT version: (7, 0, 0)
INFO:tensorflow:Loaded TensorRT version: (7, 0, 0)
INFO:tensorflow:Assets written to: saved_model/mobilenetv2_TFTRT_INT8_built/assets


In [8]:
for n in trt_graph_def.node:
    if n.op == "TRTEngineOp":
        print("Node: %s, %s" % (n.op, n.name.replace("/", "_")))
        with tf.io.gfile.GFile("%s.plan" % (n.name.replace("/", "_")), 'wb') as f:
            f.write(n.attr["serialized_segment"].s)
    else:
        print("Exclude Node: %s, %s" % (n.op, n.name.replace("/", "_")))


Exclude Node: Placeholder, input
Exclude Node: PartitionedCall, PartitionedCall
Exclude Node: Identity, Identity


In [None]:
!saved_model_cli show --all --dir saved_model/mobilenetv2_TFTRT_INT8_built

In [10]:
optimized_model = tf.saved_model.load('saved_model/mobilenetv2_TFTRT_INT8',
                                      tags=[tag_constants.SERVING])
signature_keys = list(optimized_model.signatures.keys())
print('Signature keys of optimized model: ',signature_keys)
infer = optimized_model.signatures['serving_default']
print('Outputs of serving_default: ', infer.structured_outputs)

Signature keys of optimized model:  ['serving_default']
Outputs of serving_default:  {'predictions': TensorSpec(shape=<unknown>, dtype=tf.float32, name='predictions')}


In [11]:
tfds.disable_progress_bar()
ds, metadata = tfds.load(
    'cats_vs_dogs',
    split='train',
    with_info=True,
    as_supervised=True)
get_label_name = metadata.features['label'].int2str
decode_prediction = lambda x: 1 if x>=0 else 0

In [14]:
for image, label in ds.take(10000):
    x = tf.cast(image, tf.float32)
    x = (x/127.5)-1
    x = tf.image.resize(x, (160,160))
    x = tf.expand_dims(x, axis=0)

    preds = infer(x)
    prediction = preds['predictions'][0,0] # only process first object at first batch index
    decoded_pred = decode_prediction(prediction)
    correct_prediction = label == decoded_pred

    print('{}[{}] image - correct prediction: {}'.format(get_label_name(label), label, correct_prediction))

dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
