# Inference with a TF-TRT INT8 model

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.python.compiler.tensorrt import trt_convert as trt
from tensorflow.python.saved_model import tag_constants

In [7]:
# Build one calibration batch for the INT8 converter.
# Each image is preprocessed individually *before* batching: tfds.load(..., batch_size=N)
# pads variable-size images to a common shape, which would feed padded borders into the
# INT8 calibration that the per-image inference loop below never sees.
CALIBRATION_BATCH_SIZE = 100
IMAGE_SIZE = (160, 160)  # spatial input size of the model's serving signature


def preprocess_image(image):
    """Scale pixel values from [0, 255] to [-1, 1] and resize to IMAGE_SIZE."""
    image = tf.cast(image, tf.float32)
    image = (image / 127.5) - 1
    return tf.image.resize(image, IMAGE_SIZE)


calibration_ds = tfds.load('cats_vs_dogs', split='train')
batch = next(iter(
    calibration_ds
    .map(lambda example: preprocess_image(example['image']))
    .batch(CALIBRATION_BATCH_SIZE)
))
print('batch shape: ', batch.shape)

batch shape:  (100, 160, 160, 3)


In [8]:
# Convert the FP32 SavedModel into a TF-TRT INT8 SavedModel.
# The converter reads the model straight from disk, so no Keras copy needs to be loaded here.
INPUT_SAVED_MODEL_DIR = 'saved_model/mobilenetv2'
OUTPUT_SAVED_MODEL_DIR = 'saved_model/mobilenetv2_TFTRT_INT8'
# Upper bound (~8 GB) on the GPU scratch memory TensorRT may use while building/running engines.
MAX_WORKSPACE_SIZE_BYTES = 8000000000

print('Converting to TF-TRT INT8...')
conversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(
    precision_mode=trt.TrtPrecisionMode.INT8,
    max_workspace_size_bytes=MAX_WORKSPACE_SIZE_BYTES,
    # INT8 quantization ranges are derived from the data fed by calibration_input_fn.
    use_calibration=True)

converter = trt.TrtGraphConverterV2(input_saved_model_dir=INPUT_SAVED_MODEL_DIR,
                                    conversion_params=conversion_params)


def calibration_input_fn():
    """Yield calibration inputs for the INT8 converter.

    Each yielded item is a tuple of positional model inputs; here it is the single
    preprocessed `batch` tensor defined earlier in the notebook.
    """
    yield (batch, )


converter.convert(calibration_input_fn=calibration_input_fn)
converter.save(output_saved_model_dir=OUTPUT_SAVED_MODEL_DIR)
print('Done Converting to TF-TRT INT8')

Converting to TF-TRT INT8...
INFO:tensorflow:Linked TensorRT version: (7, 0, 0)
INFO:tensorflow:Loaded TensorRT version: (7, 0, 0)
INFO:tensorflow:Assets written to: saved_model/mobilenetv2_TFTRT_INT8/assets
Done Converting to TF-TRT INT8


INFO:tensorflow:Linked TensorRT version: (7, 0, 0)
INFO:tensorflow:Loaded TensorRT version: (7, 0, 0)
INFO:tensorflow:Assets written to: saved_model/mobilenetv2_TFTRT_INT8/assets


In [9]:
# Inspect the converted SavedModel: lists its SignatureDefs with input/output names, dtypes and shapes.
!saved_model_cli show --all --dir saved_model/mobilenetv2_TFTRT_INT8

2020-06-14 21:12:14.696988: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.2

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is: 

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 160, 160, 3)
        name: serving_default_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['predictions'] tensor_info:
        dtype: DT_FLOAT
        shape: unknown_rank
        name: PartitionedCall:0
  Method

In [10]:
# Load the converted INT8 SavedModel and pick its default serving function for inference.
optimized_model = tf.saved_model.load('saved_model/mobilenetv2_TFTRT_INT8',
                                      tags=[tag_constants.SERVING])
signature_keys = list(optimized_model.signatures.keys())
print('Signature keys of optimized model: ', signature_keys)
# Public constant for 'serving_default' (instead of reaching into trt_convert's own imports).
infer = optimized_model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_KEY]
print('Outputs of serving_default: ', infer.structured_outputs)

Signature keys of optimized model:  ['serving_default']
Outputs of serving_default:  {'predictions': TensorSpec(shape=<unknown>, dtype=tf.float32, name='predictions')}


In [11]:
# Reload the training split as (image, label) pairs for evaluating the INT8 model.
tfds.disable_progress_bar()
ds, metadata = tfds.load(
    'cats_vs_dogs',
    split='train',
    with_info=True,
    as_supervised=True)
# Maps an integer label to its class name string.
get_label_name = metadata.features['label'].int2str


def decode_prediction(x):
    """Map the model's raw score to a class id: 1 if ``x >= 0`` else 0.

    Thresholding at 0 assumes the 'predictions' output is an unbounded logit
    (no sigmoid applied) -- TODO confirm against the model definition.
    """
    return 1 if x >= 0 else 0

In [14]:
# Run the INT8 model image-by-image and report accuracy.
NUM_EVAL_IMAGES = 10000  # how many images from `ds` to evaluate
NUM_PRINTED = 25         # echo only the first few per-image results to keep output readable

num_correct = 0
num_seen = 0
for image, label in ds.take(NUM_EVAL_IMAGES):
    # Same preprocessing as the calibration batch: scale to [-1, 1], resize to 160x160.
    x = tf.cast(image, tf.float32)
    x = (x / 127.5) - 1
    x = tf.image.resize(x, (160, 160))
    x = tf.expand_dims(x, axis=0)  # add a batch dimension of size 1

    preds = infer(x)
    # Single score of the single image in this batch of one.
    prediction = preds['predictions'][0, 0]
    decoded_pred = decode_prediction(prediction)
    correct_prediction = bool(label == decoded_pred)

    num_correct += int(correct_prediction)
    num_seen += 1
    if num_seen <= NUM_PRINTED:
        print('{}[{}] image - correct prediction: {}'.format(get_label_name(label), label, correct_prediction))

if num_seen:
    print('Accuracy on {} images: {:.4f}'.format(num_seen, num_correct / num_seen))

dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
dog[1] image - correct prediction: True
cat[0] image - correct prediction: True
cat[0] image - correct prediction: True
