This notebook is for timing TRT model inference vs non-TRT model inference

In [1]:
%%bash

echo 'Convert TRT:'
CONVERT_TRT_DIR=../../../tmp/trt_end_to_end/convert/trt
find ${CONVERT_TRT_DIR} -type f | sort

Convert TRT:
../../../tmp/trt_end_to_end/convert/trt/basic/001/basic_epoch001_2019-09-03T19:15_trt.pb
../../../tmp/trt_end_to_end/convert/trt/batchn/001/batchn_epoch001_2019-09-03T19:28_trt.pb
../../../tmp/trt_end_to_end/convert/trt/conv/001/conv_epoch001_2019-09-03T19:30_trt.pb
../../../tmp/trt_end_to_end/convert/trt/resnet50/001/resnet50_epoch001_2019-09-03T19:31_trt.pb


## Setup

In [2]:
import sys
sys.path.append('../../..')

## Parameters

In [3]:
_NAME = 'resnet50'
_EPOCH = 1

In [4]:
from src.utils.trt_end_to_end_constants import *
_NAME, _EPOCH, _TIME = get_params(_NAME, _EPOCH)

Metadata

In [5]:
from src.utils.trt_end_to_end_constants import MD_FILE_FORMAT

md_filename = MD_FILE_FORMAT % (_NAME, _EPOCH, _TIME)
_train_dir = get_train_dir(_NAME, _EPOCH)
md_filepath = os.path.join(_train_dir, md_filename)

from src.meta.metadata import Metadata
ret, metadata = Metadata.from_md(md_filepath)
assert ret == 0

from pprint import pprint
pprint(vars(metadata))

{'epoch': 1,
 'input_names': ['resnet50_input:0'],
 'name': 'resnet50',
 'output_names': ['fc100/Softmax:0']}


## Data

In [6]:
from src.data.cifar100 import CLASSES, INPUT_SHAPE, load_data
import numpy as np

(train_images, train_labels), (test_images, test_labels) = load_data()
input_img = np.array([train_images[0]])

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [7]:
input_img.shape

(1, 32, 32, 3)

## Helper functions

In [8]:
import tensorflow as tf
from tensorflow.python.platform import gfile
import time

# function to read a ".pb" model 
# (can be used to read frozen model or TensorRT model)
def read_pb_graph(model):
    with gfile.FastGFile(model,'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    return graph_def

## Test Params

In [9]:
_WARM_UP_TRIALS = 10000
_TRIALS = 10000

## TensorRT

In [10]:
trt_graph_dir = get_trt_graph_dir(_NAME, _EPOCH)
trt_graph_filename = TRT_GRAPH_FILE_FORMAT % (_NAME, _EPOCH, _TIME)
trt_graph_filepath = os.path.join(trt_graph_dir, trt_graph_filename)
print(trt_graph_filepath)

../../../tmp/trt_end_to_end/convert/trt/resnet50/001/resnet50_epoch001_2019-09-03T19:31_trt.pb


In [11]:
# variable

# required import or it will fail with weird message
import tensorflow.contrib.tensorrt as trt


graph = tf.Graph()
with graph.as_default():
    with tf.Session(config=tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.50))) as sess:
        # read TensorRT model
        trt_graph = read_pb_graph(trt_graph_filepath)

        # obtain the corresponding input-output tensor
        tf.import_graph_def(trt_graph, name='')
        input = sess.graph.get_tensor_by_name(metadata.input_names[0])
        output = sess.graph.get_tensor_by_name(metadata.output_names[0])
        
        # warmup
        n = _WARM_UP_TRIALS
        print("Warming up for %d trials..." % n)
        for i in range(n):
            out_pred = sess.run(output, feed_dict={input: input_img})

        # test
        total_time = 0
        n = _TRIALS
        print("Testing for %d trials..." % n)
        for i in range(n):
            t1 = time.time()
            out_pred = sess.run(output, feed_dict={input: input_img})
            delta_time = time.time() - t1
            total_time += delta_time

        avg_time_tensorRT = total_time / n
        print("TRT avg time: %ss" % avg_time_tensorRT)

W0903 20:06:21.045108 139675237521216 deprecation.py:323] From <ipython-input-8-3a99888420c0>:8: FastGFile.__init__ (from tensorflow.python.platform.gfile) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.gfile.GFile.


Warming up for 10000 trials...
Testing for 10000 trials...
TRT avg time: 0.0022118839502334596s


## Frozen Graph

In [12]:
frozen_graph_dir = get_frozen_graph_dir(_NAME, _EPOCH)

frozen_graph_filename = FROZEN_GRAPH_FILE_FORMAT % (_NAME, _EPOCH, _TIME)
frozen_graph_filepath = os.path.join(frozen_graph_dir, frozen_graph_filename)
print(frozen_graph_filepath)

../../../tmp/trt_end_to_end/convert/tf/frozen/resnet50/001/resnet50_epoch001_2019-09-03T19:31_frozen.pb


In [13]:
graph = tf.Graph()
with graph.as_default():
    with tf.Session() as sess:
        # read TensorRT model
        frozen_graph = read_pb_graph(frozen_graph_filepath)

        # obtain the corresponding input-output tensor
        tf.import_graph_def(frozen_graph, name='')
        input = sess.graph.get_tensor_by_name(metadata.input_names[0])
        output = sess.graph.get_tensor_by_name(metadata.output_names[0])
        
        # warmup
        n = _WARM_UP_TRIALS
        print("Warming up for %d trials..." % n)
        for i in range(n):
            out_pred = sess.run(output, feed_dict={input: input_img})

        # test
        total_time = 0
        n = _TRIALS
        print("Testing for %d trials..." % n)
        for i in range(n):
            t1 = time.time()
            out_pred = sess.run(output, feed_dict={input: input_img})
            delta_time = time.time() - t1
            total_time += delta_time
            
        avg_time_original_model = total_time / n
        print("Old avg time: %ss" % avg_time_original_model)

Warming up for 10000 trials...
Testing for 10000 trials...
Old avg time: 0.0054253507852554325s


In [14]:
print("Improvement:", avg_time_original_model/avg_time_tensorRT)

Improvement: 2.4528189124401387
