-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrt_convert.py
1851 lines (1584 loc) · 77.2 KB
/
trt_convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Exposes the Python wrapper conversion to trt_graph."""
import collections
from functools import partial # pylint: disable=g-importing-member
import os
import platform
import sys
import tempfile
import numpy as np
import six as _six
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.compiler.tensorrt import utils as trt_utils
from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.trackable import asset
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import resource
from tensorflow.python.training import saver
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Lazily load the op, since it's not available in CPU-only builds. Importing
# it at the top would cause tests that import TF-TRT to fail when they are
# built and run without CUDA/GPU.
gen_trt_ops = LazyLoader(
    "gen_trt_ops", globals(),
    "tensorflow.compiler.tf2tensorrt.ops.gen_trt_ops")

# Python bindings used below to query TensorRT availability and versions
# (e.g. is_tensorrt_enabled(), get_linked_tensorrt_version()); also loaded
# lazily.
_pywrap_py_utils = LazyLoader(
    "_pywrap_py_utils", globals(),
    "tensorflow.compiler.tf2tensorrt._pywrap_py_utils")

# Register TRT ops in python, so that when users import this module they can
# execute a TRT-converted graph without calling any of the methods in this
# module.
#
# This will call register_op_list() in
# tensorflow/python/framework/op_def_registry.py, but it doesn't register
# the op or the op kernel in C++ runtime.
try:
  gen_trt_ops.trt_engine_op  # pylint: disable=pointless-statement
except AttributeError:
  # Best effort: a missing op attribute is tolerated silently.
  # NOTE(review): only AttributeError is caught; if the gen_trt_ops module
  # itself cannot be imported, the lazy loader presumably surfaces an
  # ImportError here instead -- confirm that is intended for CPU-only builds.
  pass
def _to_bytes(s):
  """Returns `s` encoded as UTF-8 bytes if it is text, else `s` unchanged.

  Encoding uses the "surrogateescape" error handler, so text that was decoded
  with that same handler converts back to its original bytes.
  """
  if not isinstance(s, _six.text_type):
    return s
  return s.encode("utf-8", errors="surrogateescape")
def _to_string(s):
  """Returns `s` decoded from UTF-8 if it is bytes, else `s` unchanged."""
  return s.decode("utf-8") if isinstance(s, _six.binary_type) else s
class TrtPrecisionMode(object):
  """String constants for the precision modes understood by TF-TRT."""
  FP32 = "FP32"
  FP16 = "FP16"
  INT8 = "INT8"

  @staticmethod
  def supported_precision_modes():
    """Returns all accepted precision strings: upper-case, then lower-case."""
    upper_case = [
        TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8
    ]
    lower_case = [mode.lower() for mode in upper_case]
    return upper_case + lower_case
# Use a large enough number as the default max_workspace_size for TRT engines,
# so it can produce reasonable performance results with the default.
# For TRT >= 8.4, the recommendation is MAX_INT.
# NOTE: this branch is evaluated at import time; it calls into the lazily
# loaded `_pywrap_py_utils` bindings, which presumably forces them to load.
if (_pywrap_py_utils.is_tensorrt_enabled() and
    trt_utils.is_loaded_tensorrt_version_greater_equal(8, 4, 0)):
  # We must use `sys.maxsize - 512` to avoid overflow during casting.
  DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = sys.maxsize - 512
else:
  # Fallback when TensorRT is disabled or older than 8.4.0.
  DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30  # 1,073,741,824
# Names of the dynamic-shape optimization profile strategies. They are
# lower-cased before being handed to the TensorRT optimizer when explicit
# batch mode is used.
PROFILE_STRATEGY_RANGE = "Range"
PROFILE_STRATEGY_OPTIMAL = "Optimal"
PROFILE_STRATEGY_RANGE_OPTIMAL = "Range+Optimal"
PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE = "ImplicitBatchModeCompatible"


def supported_profile_strategies():
  """Returns a fresh list with the names of all supported profile strategies."""
  strategies = (
      PROFILE_STRATEGY_RANGE,
      PROFILE_STRATEGY_OPTIMAL,
      PROFILE_STRATEGY_RANGE_OPTIMAL,
      PROFILE_STRATEGY_IMPLICIT_BATCH_MODE_COMPATIBLE,
  )
  return list(strategies)
@tf_export("experimental.tensorrt.ConversionParams", v1=[])
class TrtConversionParams(
    collections.namedtuple("TrtConversionParams", [
        "max_workspace_size_bytes", "precision_mode", "minimum_segment_size",
        "maximum_cached_engines", "use_calibration", "allow_build_at_runtime"
    ])):
  """Parameters that are used for TF-TRT conversion.

  Fields:
    max_workspace_size_bytes: the maximum GPU temporary memory that the TRT
      engine can use at execution time. This corresponds to the
      'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of the strings in
      TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.
    maximum_cached_engines: max number of cached TRT engines for dynamic TRT
      ops. Created TRT engines for a dynamic dimension are cached. If the
      number of cached engines is already at max but none of them supports the
      input shapes, the TRTEngineOp will fall back to run the original TF
      subgraph that corresponds to the TRTEngineOp.
    use_calibration: this argument is ignored if precision_mode is not INT8.
      If set to True, a calibration graph will be created to calibrate the
      missing ranges. The calibration graph must be converted to an inference
      graph by running calibration with calibrate(). If set to False,
      quantization nodes will be expected for every tensor in the graph
      (excluding those which will be fused). If a range is missing, an error
      will occur. Please note that accuracy may be negatively affected if
      there is a mismatch between which tensors TRT quantizes and which
      tensors were trained with fake quantization.
    allow_build_at_runtime: whether to allow building TensorRT engines at
      runtime. If no prebuilt TensorRT engine can handle the inputs given at
      runtime, a new engine is built when allow_build_at_runtime=True;
      otherwise native TF is used.
  """

  def __new__(cls,
              max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
              precision_mode=TrtPrecisionMode.FP32,
              minimum_segment_size=3,
              maximum_cached_engines=1,
              use_calibration=True,
              allow_build_at_runtime=True):
    # Forward each field by name so the namedtuple field list above is the
    # single source of truth for the positional layout.
    return super().__new__(
        cls,
        max_workspace_size_bytes=max_workspace_size_bytes,
        precision_mode=precision_mode,
        minimum_segment_size=minimum_segment_size,
        maximum_cached_engines=maximum_cached_engines,
        use_calibration=use_calibration,
        allow_build_at_runtime=allow_build_at_runtime)
# A TrtConversionParams with every field left at its default value.
DEFAULT_TRT_CONVERSION_PARAMS = TrtConversionParams()
# Op type of the engine nodes produced by the conversion; calibrate() scans
# the converted GraphDef for nodes of this op type.
_TRT_ENGINE_OP_NAME = "TRTEngineOp"
def _check_conversion_params(conversion_params, is_v2=False):
  """Validate the provided TrtConversionParams.

  Args:
    conversion_params: a TrtConversionParams instance.
    is_v2: whether we're getting a RewriterConfig for TF 2.0. Currently not
      used by the checks; kept for API compatibility.

  Raises:
    TypeError: raised implicitly if minimum_segment_size cannot be compared
      with an integer (e.g. it is None).
    ValueError: if precision_mode is not one of
      TrtPrecisionMode.supported_precision_modes(), or if
      minimum_segment_size is neither positive nor -1.
  """
  del is_v2  # Unused.
  supported_precision_modes = TrtPrecisionMode.supported_precision_modes()
  if conversion_params.precision_mode not in supported_precision_modes:
    # Note the trailing space: the two literals are concatenated.
    raise ValueError(
        ("precision mode '{}' is not supported. "
         "It should be one of {}").format(conversion_params.precision_mode,
                                          supported_precision_modes))
  # -1 is a sentinel that disables conversion of the main graph.
  if (conversion_params.minimum_segment_size <= 0 and
      conversion_params.minimum_segment_size != -1):
    raise ValueError("minimum segment size should be positive or -1 "
                     "(to disable main graph conversion).")
def _check_trt_version_compatibility():
  """Check compatibility of TensorRT version.

  Logs the linked and loaded TensorRT versions, and logs a warning when
  running on Windows.

  Raises:
    RuntimeError: if TensorFlow was built without TensorRT support, if the
      linked or loaded TensorRT version is older than 7.0.0, or if the loaded
      version has a different major version than, or is older than, the
      linked version.
  """
  if not _pywrap_py_utils.is_tensorrt_enabled():
    logging.error(
        "Tensorflow needs to be built with TensorRT support enabled to allow "
        "TF-TRT to operate.")

    raise RuntimeError("Tensorflow has not been built with TensorRT support.")

  if platform.system() == "Windows":
    # `warning` is the non-deprecated spelling of `warn`.
    logging.warning(
        "Windows support is provided experimentally. No guarantee is made "
        "regarding functionality or engineering support. Use at your own risk.")

  linked_version = _pywrap_py_utils.get_linked_tensorrt_version()
  loaded_version = _pywrap_py_utils.get_loaded_tensorrt_version()

  logging.info("Linked TensorRT version: %s", str(linked_version))
  logging.info("Loaded TensorRT version: %s", str(loaded_version))

  def raise_trt_version_deprecated(version_type, trt_version):
    """Logs and raises for a TensorRT version below the supported minimum.

    Args:
      version_type: either "linked" or "loaded".
      trt_version: the offending version tuple.

    Raises:
      RuntimeError: always.
    """
    # Internal invariant check; callers in this function only pass the two
    # accepted values.
    assert version_type in [
        "linked", "loaded"
    ], ("Incorrect value received for version_type: %s. Accepted: ['linked', "
        "'loaded']") % version_type

    logging.error(
        "The {version_type} version of TensorRT: `{trt_version}` has now "
        "been removed. Please upgrade to TensorRT 7 or more recent.".format(
            version_type=version_type,
            trt_version=trt_utils.version_tuple_to_string(trt_version)))

    raise RuntimeError("Incompatible %s TensorRT versions" % version_type)

  if not trt_utils.is_linked_tensorrt_version_greater_equal(7, 0, 0):
    raise_trt_version_deprecated("linked", linked_version)

  if not trt_utils.is_loaded_tensorrt_version_greater_equal(7, 0, 0):
    raise_trt_version_deprecated("loaded", loaded_version)

  # The major versions must match, and the loaded version must not be older
  # than the one TensorFlow was linked against.
  if (loaded_version[0] != linked_version[0] or
      not trt_utils.is_loaded_tensorrt_version_greater_equal(*linked_version)):
    logging.error(
        "Loaded TensorRT %s but linked TensorFlow against TensorRT %s. A few "
        "requirements must be met:\n"
        "\t-It is required to use the same major version of TensorRT during "
        "compilation and runtime.\n"
        "\t-TensorRT does not support forward compatibility. The loaded "
        "version has to be equal or more recent than the linked version.",
        trt_utils.version_tuple_to_string(loaded_version),
        trt_utils.version_tuple_to_string(linked_version))
    raise RuntimeError("Incompatible TensorRT major version")

  elif loaded_version != linked_version:
    logging.info(
        "Loaded TensorRT %s and linked TensorFlow against TensorRT %s. This is "
        "supported because TensorRT minor/patch upgrades are backward "
        "compatible.", trt_utils.version_tuple_to_string(loaded_version),
        trt_utils.version_tuple_to_string(linked_version))
def _get_tensorrt_rewriter_config(conversion_params,
                                  is_dynamic_op=None,
                                  max_batch_size=None,
                                  is_v2=False,
                                  disable_non_trt_optimizers=False,
                                  use_implicit_batch=True,
                                  profile_strategy=PROFILE_STRATEGY_RANGE):
  """Returns a RewriterConfig proto for TRT transformation.

  Args:
    conversion_params: a TrtConversionParams instance.
    is_dynamic_op: whether to use dynamic engines. Required (not None) for
      TF1. For TF2 it must be None or True; None is treated as True.
    max_batch_size: maximum batch size for static engines. Must be None unless
      is_dynamic_op is False, in which case it must be an int.
    is_v2: whether we're getting a RewriterConfig for TF 2.0.
    disable_non_trt_optimizers: Turn off all default Grappler optimizers.
    use_implicit_batch: Whether to use implicit batch or explicit batch.
    profile_strategy: dynamic shape optimization profile strategy. Only
      forwarded (lower-cased) when use_implicit_batch is False.

  Returns:
    A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.

  Raises:
    TypeError: if any of the parameters are of unexpected type.
    ValueError: if any of the parameters are of unexpected value.
  """
  _check_conversion_params(conversion_params, is_v2=is_v2)
  if is_v2 and is_dynamic_op is not None and not is_dynamic_op:
    raise ValueError("is_dynamic_op is either None or True for TF2")
  if not is_v2 and is_dynamic_op is None:
    raise ValueError("is_dynamic_op can't be None for TF1")

  if (is_dynamic_op is None or is_dynamic_op) and max_batch_size is not None:
    raise ValueError("max_batch_size has to be None for TF2"
                     " or when is_dynamic_op == True in TF1")
  if is_dynamic_op is not None and not is_dynamic_op and not isinstance(
      max_batch_size, int):
    raise ValueError(
        "max_batch_size has to be an integer for is_dynamic_op==False in TF1")

  # The checks above only let is_dynamic_op=None through for TF2, where it
  # means dynamic engines. Resolve it to a real bool: assigning None to the
  # proto bool field further down would raise a TypeError.
  is_dynamic_op = True if is_dynamic_op is None else bool(is_dynamic_op)

  rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
  # Disable Grappler Remapper to avoid fused OPs that may not be
  # beneficial to TF-TRT and are not supported by TF-TRT.
  rewriter_config_with_trt.remapping = False

  # Prevent folding of Const->QDQ chains. Enabled when either the linked or
  # the loaded TensorRT version is at least 8.0.0.
  rewriter_config_with_trt. \
      experimental_disable_folding_quantization_emulation = (
          trt_utils.is_linked_tensorrt_version_greater_equal(8, 0, 0) or
          trt_utils.is_loaded_tensorrt_version_greater_equal(8, 0, 0))

  if not disable_non_trt_optimizers:
    rewriter_config_with_trt.optimizers.extend([
        "pruning", "debug_stripper", "layout", "dependency", "constfold",
        "common_subgraph_elimination"
    ])

  rewriter_config_with_trt.meta_optimizer_iterations = (
      rewriter_config_pb2.RewriterConfig.ONE)
  optimizer = rewriter_config_with_trt.custom_optimizers.add()

  if not disable_non_trt_optimizers:
    # Add a constfold optimizer to cleanup the unused Const nodes.
    rewriter_config_with_trt.custom_optimizers.add().name = "constfold"

  optimizer.name = "TensorRTOptimizer"
  optimizer.parameter_map[
      "minimum_segment_size"].i = conversion_params.minimum_segment_size
  optimizer.parameter_map["max_workspace_size_bytes"].i = (
      conversion_params.max_workspace_size_bytes)
  optimizer.parameter_map["precision_mode"].s = _to_bytes(
      conversion_params.precision_mode)
  optimizer.parameter_map[
      "maximum_cached_engines"].i = conversion_params.maximum_cached_engines
  optimizer.parameter_map[
      "use_calibration"].b = conversion_params.use_calibration
  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
  optimizer.parameter_map[
      "allow_build_at_runtime"].b = conversion_params.allow_build_at_runtime
  if max_batch_size is not None:
    optimizer.parameter_map["max_batch_size"].i = max_batch_size
  optimizer.parameter_map["use_implicit_batch"].b = use_implicit_batch
  # While we accept case insensitive strings from the users, we only pass the
  # strings in lower cases to TF-TRT converter.
  if not use_implicit_batch:
    optimizer.parameter_map["profile_strategy"].s = _to_bytes(
        profile_strategy.lower())

  # Disabling optimizers should happen after defining the TF-TRT grappler pass
  # otherwise the template can overwrite the disablement.
  if disable_non_trt_optimizers:
    trt_utils.disable_non_trt_optimizers_in_rewriter_config(
        rewriter_config_with_trt)

  return rewriter_config_with_trt
@deprecation.deprecated(
    None, "You shouldn't need a rewriter_config with the current TF-TRT APIs.")
def get_tensorrt_rewriter_config(conversion_params,
                                 is_dynamic_op=None,
                                 max_batch_size=None,
                                 is_v2=False,
                                 disable_non_trt_optimizers=False):
  """Deprecated public wrapper around `_get_tensorrt_rewriter_config`.

  `use_implicit_batch` and `profile_strategy` are left at the private
  helper's defaults.
  """
  return _get_tensorrt_rewriter_config(
      conversion_params=conversion_params,
      is_dynamic_op=is_dynamic_op,
      max_batch_size=max_batch_size,
      is_v2=is_v2,
      disable_non_trt_optimizers=disable_non_trt_optimizers)
# Remove all scope prefixes in the node name. In TF 2.0, the same concrete
# function can be initialized multiple times with different prefixes, and
# this will result in the same TRTEngineOp being initialized multiple times
# with different cache and duplicate TRT engines.
# TODO(laigd): this may be caused by the fact that TRTEngineOp is not
# stateful, need to investigate.
# TODO(laigd): we rely on the fact that all functions are fully inlined
# before TF-TRT optimizer is called, as otherwise it may generate the same
# name when optimizing a different function graph. Fix this.
def _get_canonical_engine_name(name):
return name.split("/")[-1]
class TrtGraphConverter(object):
"""A converter for TF-TRT transformation for TF 1.x GraphDef/SavedModels.
To run the conversion without quantization calibration (e.g. for FP32/FP16
precision modes):
```python
converter = TrtGraphConverter(
input_saved_model_dir="my_dir",
precision_mode=TrtPrecisionMode.FP16)
converted_graph_def = converter.convert()
converter.save(output_saved_model_dir)
```
To run the conversion with quantization calibration:
```python
converter = TrtGraphConverter(
input_saved_model_dir="my_dir",
precision_mode=TrtPrecisionMode.INT8)
converter.convert()
# Run calibration 10 times.
converted_graph_def = converter.calibrate(
fetch_names=['output:0'],
num_runs=10,
feed_dict_fn=lambda: {'input:0': my_next_data()})
converter.save(output_saved_model_dir)
```
"""
def __init__(self,
             input_saved_model_dir=None,
             input_saved_model_tags=None,
             input_saved_model_signature_key=None,
             input_graph_def=None,
             nodes_denylist=None,
             max_batch_size=1,
             max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
             precision_mode=TrtPrecisionMode.FP32,
             minimum_segment_size=3,
             is_dynamic_op=False,
             maximum_cached_engines=1,
             use_calibration=True):
  """Initializes the converter.

  Args:
    input_saved_model_dir: the directory to load the SavedModel which contains
      the input graph to transforms. Used only when input_graph_def is None.
    input_saved_model_tags: list of tags to load the SavedModel. Defaults to
      [tag_constants.SERVING].
    input_saved_model_signature_key: the key of the signature to optimize the
      graph for. Defaults to DEFAULT_SERVING_SIGNATURE_DEF_KEY.
    input_graph_def: a GraphDef object containing a model to be transformed.
      If set to None, the graph will be read from the SavedModel loaded from
      input_saved_model_dir.
    nodes_denylist: list of node names to prevent the converter from touching.
    max_batch_size: max size for the input batch. Must be an int when
      is_dynamic_op is False; ignored (with a warning if not None) otherwise.
    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
      engine can use at execution time. This corresponds to the
      'workspaceSize' parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
    minimum_segment_size: the minimum number of nodes required for a subgraph
      to be replaced by TRTEngineOp.
    is_dynamic_op: whether to generate dynamic TRT ops which will build the
      TRT network and engine at run time. Forced to True for INT8 with
      calibration.
    maximum_cached_engines: max number of cached TRT engines in dynamic TRT
      ops. If the number of cached engines is already at max but none of them
      can serve the input, the TRTEngineOp will fall back to run the TF
      function based on which the TRTEngineOp is created.
    use_calibration: this argument is ignored if precision_mode is not INT8.
      If set to True, a calibration graph will be created to calibrate the
      missing ranges. The calibration graph must be converted to an inference
      graph by running calibration with calibrate(). If set to False,
      quantization nodes will be expected for every tensor in the graph
      (excluding those which will be fused). If a range is missing, an error
      will occur. Please note that accuracy may be negatively affected if
      there is a mismatch between which tensors TRT quantizes and which
      tensors were trained with fake quantization.

  Raises:
    ValueError: if the combination of the parameters is invalid.
    RuntimeError: if this class is used in TF 2.0, or if the TensorRT setup
      is incompatible.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "Please use tf.experimental.tensorrt.Converter in TF 2.0.")

  if input_graph_def and input_saved_model_dir:
    raise ValueError(
        "Can only specify one of input_graph_def and input_saved_model_dir")
  if not input_graph_def and not input_saved_model_dir:
    raise ValueError("Must specify one of input_graph_def and "
                     "input_saved_model_dir")
  _check_trt_version_compatibility()

  self._input_graph_def = input_graph_def
  self._nodes_denylist = nodes_denylist

  self._input_saved_model_dir = input_saved_model_dir
  self._converted = False
  # Set by convert(); initialized here so the attribute always exists.
  self._converted_graph_def = None
  self._grappler_meta_graph_def = None

  self._input_saved_model_tags = (
      input_saved_model_tags or [tag_constants.SERVING])
  self._input_saved_model_signature_key = (
      input_saved_model_signature_key or
      signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)

  # For calibration usage.
  self._calibration_graph = None
  self._calibration_data_collected = False
  self._need_calibration = (
      ((precision_mode == TrtPrecisionMode.INT8) or
       (precision_mode == TrtPrecisionMode.INT8.lower())) and use_calibration)
  if self._need_calibration and not is_dynamic_op:
    logging.warning(
        "INT8 precision mode with calibration is supported with "
        "dynamic TRT ops only. Disregarding is_dynamic_op parameter.")
    is_dynamic_op = True

  self._is_dynamic_op = is_dynamic_op
  if is_dynamic_op:
    self._max_batch_size = None
    if max_batch_size is not None:
      logging.warning(
          "When is_dynamic_op==True max_batch_size should be None")
  else:
    if not isinstance(max_batch_size, int):
      raise ValueError("When is_dynamic_op==False max_batch_size should be "
                       "an integer")
    self._max_batch_size = max_batch_size

  self._conversion_params = TrtConversionParams(
      max_workspace_size_bytes=max_workspace_size_bytes,
      precision_mode=precision_mode,
      minimum_segment_size=minimum_segment_size,
      maximum_cached_engines=maximum_cached_engines,
      use_calibration=use_calibration,
      allow_build_at_runtime=True)
  _check_conversion_params(self._conversion_params)

  self._test_only_disable_non_trt_optimizers = False
def _run_conversion(self):
  """Run Grappler's OptimizeGraph() tool to convert the graph.

  Optimizes self._grappler_meta_graph_def with a ConfigProto whose rewrite
  options come from _get_tensorrt_rewriter_config (implicit batch mode),
  stores the result in self._converted_graph_def and marks the converter as
  converted.
  """
  trt_rewriter_config = _get_tensorrt_rewriter_config(
      conversion_params=self._conversion_params,
      is_dynamic_op=self._is_dynamic_op,
      max_batch_size=self._max_batch_size,
      disable_non_trt_optimizers=self._test_only_disable_non_trt_optimizers,
      use_implicit_batch=True)
  session_config = config_pb2.ConfigProto()
  session_config.graph_options.rewrite_options.CopyFrom(trt_rewriter_config)

  self._converted_graph_def = tf_optimizer.OptimizeGraph(
      session_config, self._grappler_meta_graph_def, graph_id=b"tf_graph")
  self._converted = True
def _add_nodes_denylist(self):
  """Appends the user's denylisted nodes to the "train_op" collection.

  Tensors contribute their tensor name; any other entry is used as-is. Each
  name is stored as bytes. Does nothing when no denylist was given.
  """
  # NOTE(review): presumably the TRT optimizer treats nodes listed in this
  # collection as ones it must not absorb into engines -- confirm.
  if not self._nodes_denylist:
    return
  node_names = self._grappler_meta_graph_def.collection_def[
      "train_op"].node_list.value
  for entry in self._nodes_denylist:
    name = entry.name if isinstance(entry, ops.Tensor) else entry
    node_names.append(_to_bytes(name))
def _convert_graph_def(self):
  """Convert the input GraphDef.

  Imports the GraphDef into a fresh graph, exports it (with shapes) as the
  MetaGraphDef handed to Grappler, applies the node denylist and runs the
  conversion.
  """
  source_graph = ops.Graph()
  with source_graph.as_default():
    importer.import_graph_def(self._input_graph_def, name="")
    shaped_graph_def = source_graph.as_graph_def(add_shapes=True)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=shaped_graph_def, graph=source_graph)
  self._add_nodes_denylist()
  self._run_conversion()
def _collections_to_keep(self, collection_keys):
  """Returns, in input order, the collection keys that may be carried over.

  Variable collections and the train-op, while-context and cond-context
  collections are filtered out.
  """
  # TODO(laigd): currently we use the collection key to filter out
  # collections that depend on variable ops, but this may miss some
  # other user-defined collections. A better way would be to use
  # CollectionDef::NodeList for the filtering.
  excluded_keys = set(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
  excluded_keys.update((ops.GraphKeys.TRAIN_OP, ops.GraphKeys.WHILE_CONTEXT,
                        ops.GraphKeys.COND_CONTEXT))
  kept_keys = []
  for key in collection_keys:
    if key not in excluded_keys:
      kept_keys.append(key)
  return kept_keys
def _convert_saved_model(self):
  """Convert the input SavedModel.

  Loads the SavedModel and freezes its variables into constants. The nodes
  referenced by the selected signature and by the retained collections are
  kept. It then builds the MetaGraphDef handed to Grappler and runs the
  conversion.

  Raises:
    ValueError: if the SavedModel has no signature named
      self._input_saved_model_signature_key.
  """
  graph = ops.Graph()
  with session.Session(graph=graph) as sess:
    input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
                                       self._input_saved_model_dir)
    signature_defs = input_meta_graph_def.signature_def
    # Check membership first: indexing a message-valued proto map with a
    # missing key silently inserts an empty SignatureDef, which would freeze
    # and convert the graph without any signature inputs/outputs.
    if self._input_saved_model_signature_key not in signature_defs:
      raise ValueError(
          "Signature key '{}' not found in the SavedModel at '{}'. Available "
          "signature keys: {}".format(self._input_saved_model_signature_key,
                                      self._input_saved_model_dir,
                                      sorted(signature_defs.keys())))
    input_signature_def = signature_defs[
        self._input_saved_model_signature_key]

    def _gather_names(tensor_info):
      """Get the node names from a TensorInfo."""
      return {tensor_info[key].name.split(":")[0] for key in tensor_info}

    # Get inputs and outputs from the selected SignatureDef.
    output_node_names = _gather_names(input_signature_def.inputs).union(
        _gather_names(input_signature_def.outputs))

    # Preserve nodes in collection
    for collection_key in self._collections_to_keep(
        input_meta_graph_def.collection_def):
      for op in sess.graph.get_collection(collection_key):
        if isinstance(op, ops.Operation):
          output_node_names.add(op.name.split(":")[0])

    # Freeze the variables in the SavedModel graph and copy the frozen
    # graph over.
    frozen_graph_def = convert_to_constants.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(add_shapes=True),
        list(output_node_names))
    self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
    self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

    # Copy the collections that are not variables.
    for collection_key in self._collections_to_keep(
        input_meta_graph_def.collection_def):
      self._grappler_meta_graph_def.collection_def[collection_key].CopyFrom(
          input_meta_graph_def.collection_def[collection_key])

    self._add_nodes_denylist()

    # Copy other information.
    self._grappler_meta_graph_def.meta_info_def.CopyFrom(
        input_meta_graph_def.meta_info_def)
    self._grappler_meta_graph_def.signature_def[
        self._input_saved_model_signature_key].CopyFrom(input_signature_def)
    # TODO(laigd): maybe add back AssetFileDef.

  self._run_conversion()
def convert(self):
  """Run the TF-TRT conversion.

  Converts the input GraphDef if one was given, otherwise the input
  SavedModel. May only be called once per converter (asserted).

  Returns:
    The converted GraphDef for TF 1.x.
  """
  assert not self._converted
  run_conversion_for_input = (
      self._convert_graph_def
      if self._input_graph_def else self._convert_saved_model)
  run_conversion_for_input()
  return self._converted_graph_def
def calibrate(self,
              fetch_names,
              num_runs,
              feed_dict_fn=None,
              input_map_fn=None):
  """Run the calibration and return the calibrated GraphDef.

  Args:
    fetch_names: a list of output tensor name to fetch during calibration.
    num_runs: number of runs of the graph during calibration.
    feed_dict_fn: a function that returns a dictionary mapping input names (as
      strings) in the GraphDef to be calibrated to values (e.g. Python list,
      numpy arrays, etc). One and only one of `feed_dict_fn` and
      `input_map_fn` should be specified. Called once per calibration run.
    input_map_fn: a function that returns a dictionary mapping input names (as
      strings) in the GraphDef to be calibrated to Tensor objects. The values
      of the named input tensors in the GraphDef to be calibrated will be
      re-mapped to the respective `Tensor` values during calibration. One and
      only one of `feed_dict_fn` and `input_map_fn` should be specified.
      Called exactly once, with the calibration graph as the default graph.

  Raises:
    ValueError: if the input combination is invalid.
    AssertionError: if called before convert(), when calibration is not
      needed, or after calibration data was already collected.

  Returns:
    The GraphDef after the calibration.
  """
  assert self._converted
  assert self._need_calibration
  assert not self._calibration_data_collected

  if (feed_dict_fn and input_map_fn) or (not feed_dict_fn and
                                         not input_map_fn):
    raise ValueError(
        "Should specify one and only one of feed_dict_fn and input_map_fn.")

  self._calibration_graph = ops.Graph()
  with self._calibration_graph.as_default():
    # Call input_map_fn exactly once, inside the calibration graph, so any
    # tensors it creates live in the graph they are mapped into and a
    # stateful callable is not invoked a second time just for validation.
    input_map = None
    if input_map_fn:
      input_map = input_map_fn()
      for k, v in input_map.items():
        if not isinstance(k, str):
          raise ValueError("Keys of input_map_fn must be of type str")
        if not isinstance(v, ops.Tensor):
          raise ValueError("Values of input_map_fn must be of type tf.Tensor")

    fetches = importer.import_graph_def(
        self._converted_graph_def,
        input_map=input_map,
        return_elements=fetch_names,
        name="")

  calibrate_rewriter_cfg = rewriter_config_pb2.RewriterConfig()
  if self._test_only_disable_non_trt_optimizers:
    trt_utils.disable_non_trt_optimizers_in_rewriter_config(
        calibrate_rewriter_cfg)

  # Set allow_soft_placement=True to run the graph for calibration so that
  # OPs supported by TensorRT but don't have a GPU implementation are allowed
  # to execute on CPU.
  calibrate_config = config_pb2.ConfigProto(
      allow_soft_placement=True,
      graph_options=config_pb2.GraphOptions(
          rewrite_options=calibrate_rewriter_cfg))

  with session.Session(
      graph=self._calibration_graph,
      config=calibrate_config) as calibration_sess:
    for _ in range(num_runs):
      calibration_sess.run(
          fetches, feed_dict=feed_dict_fn() if feed_dict_fn else None)

    # Maps device name to the corresponding get_calibration_data.
    #
    # TODO(laigd): a better way would be to use calibration_sess to list
    # all the devices, add one get_calibration_data for each device, and
    # fetch each such op for every resource until it's found. This can work
    # even when the device of the TRTEngineOp is empty or not fully specified.
    device_to_get_resource_op_map = {}

    with self._calibration_graph.as_default():
      resource_name_input = array_ops.placeholder(dtypes.string)

      for node in self._converted_graph_def.node:
        if node.op == _TRT_ENGINE_OP_NAME:
          # Adds the get_calibration_data op for the device if not done
          # before. We only add one such op for each device.
          # TODO(laigd): What if the device is empty?????
          if node.device not in device_to_get_resource_op_map:
            with self._calibration_graph.device(node.device):
              serialized_resources_output = (
                  gen_trt_ops.get_calibration_data_op(resource_name_input))
            device_to_get_resource_op_map[node.device] = (
                serialized_resources_output)

          # Get the calibration resource and store it on the engine node of
          # the converted GraphDef (mutated in place).
          calibration_result = calibration_sess.run(
              device_to_get_resource_op_map[node.device],
              feed_dict={
                  resource_name_input: _get_canonical_engine_name(node.name)
              })
          node.attr["calibration_data"].s = calibration_result

    self._calibration_data_collected = True

  return self._converted_graph_def
def save(self, output_saved_model_dir):
  """Save the converted graph as a SavedModel.

  Args:
    output_saved_model_dir: construct a SavedModel using the converted
      GraphDef and save it to the specified directory. This option only works
      when the input graph is loaded from a SavedModel, i.e. when
      input_saved_model_dir is specified and input_graph_def is None in
      __init__().

  Raises:
    ValueError: if the input to the converter is a GraphDef instead of a
      SavedModel.
  """
  # Preconditions: the graph must already be converted and, when calibration
  # is required, calibration data must have been collected.
  # NOTE(review): these are plain asserts and are skipped under `python -O`.
  assert self._converted
  if self._need_calibration:
    assert self._calibration_data_collected
  if self._input_graph_def:
    raise ValueError(
        "Not able to save to a SavedModel since input is a GraphDef")

  def _restore_collections(dest_graph, src_meta_graph_def, collection_keys):
    """Restores collections that we need to keep.

    Copies every collection listed in `collection_keys` from
    `src_meta_graph_def` into `dest_graph`. Entries that cannot be
    reconstructed, or whose node cannot be found in `dest_graph`, are
    skipped.
    """
    scope = ""
    for key in collection_keys:
      collection_def = src_meta_graph_def.collection_def[key]
      kind = collection_def.WhichOneof("kind")
      if kind is None:
        logging.error(
            "Cannot identify data type for collection %s. Skipping.", key)
        continue
      from_proto = ops.get_from_proto_function(key)
      if from_proto and kind == "bytes_list":
        proto_type = ops.get_collection_proto_type(key)
        # It is assumed that there are no Variables Keys in collections
        for value in collection_def.bytes_list.value:
          proto = proto_type()
          proto.ParseFromString(value)
          try:
            new_value = from_proto(proto, import_scope=scope)
          except Exception:  # pylint: disable=broad-except
            # Best effort: skip values that can no longer be rebuilt
            # (presumably because they refer to nodes removed by the
            # conversion). Catch Exception rather than using a bare `except:`
            # so that KeyboardInterrupt / SystemExit still propagate.
            continue
          dest_graph.add_to_collection(key, new_value)
      else:
        field = getattr(collection_def, kind)
        if kind == "node_list":
          for value in field.value:
            name = ops.prepend_name_scope(value, scope)
            # Since the graph has been optimized, the node may no longer
            # exist.
            try:
              col_op = dest_graph.as_graph_element(name)
            except (TypeError, ValueError, KeyError):
              continue
            dest_graph.add_to_collection(key, col_op)
        elif kind == "int64_list":
          # NOTE(opensource): This force conversion is to work around the
          # fact that Python2 distinguishes between int and long, while
          # Python3 has only int.
          for value in field.value:
            dest_graph.add_to_collection(key, int(value))
        else:
          for value in field.value:
            dest_graph.add_to_collection(key,
                                         ops.prepend_name_scope(value, scope))

  # Write the transformed graphdef as SavedModel.
  saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
  with ops.Graph().as_default():
    importer.import_graph_def(self._converted_graph_def, name="")
    _restore_collections(
        ops.get_default_graph(), self._grappler_meta_graph_def,
        self._collections_to_keep(
            self._grappler_meta_graph_def.collection_def))
    # We don't use any specific converter here.
    with session.Session() as sess:
      saved_model_builder.add_meta_graph_and_variables(
          sess,
          self._input_saved_model_tags,
          signature_def_map=self._grappler_meta_graph_def.signature_def)
  # Ignore other meta graphs from the input SavedModel.
  saved_model_builder.save()
def _get_resource_handle(name, device):
  """Creates a TRT resource handle for resource `name`, placed on `device`.

  Args:
    name: the resource name passed to `create_trt_resource_handle`.
    device: the device scope under which the handle op is created.

  Returns:
    The handle produced by `gen_trt_ops.create_trt_resource_handle`.
  """
  with ops.device(device):
    handle = gen_trt_ops.create_trt_resource_handle(resource_name=name)
  return handle
def _remove_native_segments(input_func):
  """Remove native segments from the input TF-TRT Converted Function.

  Every library function whose signature name contains "native_segment" is
  removed from both the function's GraphDef and the eager context. The
  `segment_func` attribute is then cleared from every TRTEngineOp node, both
  top-level nodes and nodes inside library functions. Without that
  attribute, the TRTEngineOp constructor does not look for the deleted
  native segment handles.

  Args:
    input_func: provide the concrete function with native segment nodes. The
      transformed output func will not contain any native segment nodes. All
      the TRTEngineOp references will be deleted and reset to default empty
      func.

  Returns:
    A new function rebuilt from the stripped GraphDef via
    `_construct_function_from_graph_def`.
  """
  input_graph_def = input_func.graph.as_graph_def()
  library = input_graph_def.library.function

  # Deleting the Native Segment node in each TRTEngineOp node. Iterate in
  # reverse so deleting an entry does not shift the indices still to visit.
  nodes_deleted = 0
  for func_id in reversed(range(len(library))):
    func_name = library[func_id].signature.name
    if "native_segment" not in func_name:
      continue
    nodes_deleted += 1
    while context.context().has_function(func_name):
      context.context().remove_function(func_name)
    del library[func_id]
  logging.info("Found and deleted native segments from %d TRTEngineOp nodes.",
               nodes_deleted)

  def _strip_segment_func(nodes):
    """Drops the `segment_func` attr from every TRTEngineOp in `nodes`."""
    for node in nodes:
      # Tolerate nodes that no longer carry the attr instead of raising
      # KeyError from the protobuf map deletion.
      if node.op == "TRTEngineOp" and "segment_func" in node.attr:
        del node.attr["segment_func"]

  # Deleting the references to `<EngineName>_native_segment`s.
  _strip_segment_func(input_graph_def.node)
  for func in input_graph_def.library.function:
    _strip_segment_func(func.node_def)

  # Reconstruct the converted_func with the new graph
  return _construct_function_from_graph_def(input_func, input_graph_def)
class _TRTEngineResource(resource.TrackableResource):
  """Trackable resource that tracks the serialized TensorRT engines.

  The serialized engine file `filename` is tracked as an asset of this
  object, so it is saved with the SavedModel. The underlying resource is
  identified by `resource_name` and is initialized from that file.
  """

  def __init__(self,
               resource_name,
               filename,
               maximum_cached_engines,
               device="GPU"):
    super().__init__(device=device)
    self._resource_name = resource_name
    # Track the serialized engine file in the SavedModel.
    engine_asset = asset.Asset(filename)
    self._filename = self._track_trackable(
        engine_asset, "_serialized_trt_resource_filename")
    self._maximum_cached_engines = maximum_cached_engines

  def _create_resource(self):
    # `_resource_device` is presumably set by the TrackableResource base
    # class from the `device` passed to __init__.
    return _get_resource_handle(self._resource_name, self._resource_device)

  def _initialize(self):
    cache_limit = self._maximum_cached_engines
    gen_trt_ops.initialize_trt_resource(
        self.resource_handle,
        self._filename,
        max_cached_engines_count=cache_limit)

  def _destroy_resource(self):
    target_device = self._resource_device
    handle = _get_resource_handle(self._resource_name, target_device)
    with ops.device(target_device):
      # ignore_lookup_error: destroying an already-missing resource is a no-op.
      gen_resource_variable_ops.destroy_resource_op(
          handle, ignore_lookup_error=True)
def _print_row(fields, positions, print_fn):
"""Prints a row."""
line = ""
for i, field in enumerate(fields):
field = str(field)
end_line_pos = positions[i]
if i > 0:
line = line + " "
line = "{0:{min_length}}".format(line + field, min_length=end_line_pos)
if len(line) > end_line_pos:
line = line[:(end_line_pos - 4)] + " ..."
print_fn(line)
def _construct_function_from_graph_def(func, graph_def, frozen_func=None):
  """Rebuild a function from `graph_def`.

  Args:
    func: the function whose structured outputs and structured input
      signature are copied onto the rebuilt function.
    graph_def: the GraphDef, including its function library, to wrap.
    frozen_func: the function whose inputs, outputs and by-value captures
      define the rebuilt function's interface. Defaults to `func`.

  Returns:
    The function created by `wrap_function.function_from_graph_def`.
  """
  source_func = func if frozen_func is None else frozen_func

  # When a function has been converted, the TF context still holds the
  # original version while `graph_def` holds the converted one under the
  # same name. Drop the context copy so the converted function is used.
  ctx = context.context()
  for library_fn in graph_def.library.function:
    fn_name = library_fn.signature.name
    while ctx.has_function(fn_name):
      ctx.remove_function(fn_name)

  # Map each captured placeholder's op name (tensor name without ":N") to
  # the external value it captures.
  by_val = source_func.graph._function_captures.by_val_captures  # pylint: disable = protected-access
  captures = {}
  for capture in by_val.values():
    captures[capture.internal.name.split(":")[0]] = capture.external

  input_names = [t.name for t in source_func.inputs]
  output_names = [t.name for t in source_func.outputs]
  new_func = wrap_function.function_from_graph_def(
      graph_def, input_names, output_names, captures)

  new_graph = new_func.graph
  new_graph.structured_outputs = nest.pack_sequence_as(
      func.graph.structured_outputs, new_graph.structured_outputs)
  # The structured input signature is copied from the original function
  # because serialization relies on it.
  new_graph.structured_input_signature = func.structured_input_signature
  return new_func
def _apply_inlining(func):
"""Apply an inlining optimization to the function's graph definition."""
graph_def = func.graph.as_graph_def()
# In some cases, a secondary implementation of the function (e.g. for GPU) is
# written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in
# TF2 produces a CuDNN-based RNN for GPU).
# This function suppose to inline all functions calls, but "api_implements"
# prevents this from happening. Removing the attribute solves the problem.
# To learn more about "api_implements", see:
# tensorflow/core/grappler/optimizers/implementation_selector.h
for function in graph_def.library.function:
if "api_implements" in function.attr:
del function.attr["api_implements"]
meta_graph = saver.export_meta_graph(graph_def=graph_def, graph=func.graph)
# Clear the initializer_name for the variables collections, since they are not
# needed after saved to saved_model.
for name in [
"variables", "model_variables", "trainable_variables", "local_variables"
]:
raw_list = []
for raw in meta_graph.collection_def["variables"].bytes_list.value:
variable = variable_pb2.VariableDef()
variable.ParseFromString(raw)
variable.ClearField("initializer_name")
raw_list.append(variable.SerializeToString())
meta_graph.collection_def[name].bytes_list.value[:] = raw_list
# Add a collection 'train_op' so that Grappler knows the outputs.
fetch_collection = meta_graph_pb2.CollectionDef()
for array in func.inputs + func.outputs:
fetch_collection.node_list.value.append(array.name)
meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)