/
trt_engine_op.cc
1448 lines (1321 loc) · 60.7 KB
/
trt_engine_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/common/utils.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/env_var.h"
#if GOOGLE_CUDA && GOOGLE_TENSORRT
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/tensorrt/NvInfer.h"
namespace tensorflow {
namespace tensorrt {
namespace {
// Logger instance returned by Logger::GetLogger(), bound once at static
// initialization. NOTE(review): presumably a shared singleton — confirm in
// trt_logger.h.
Logger& logger = *Logger::GetLogger();
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;
// Starts a WARNING log line prefixed with "TF-TRT Warning: " that is emitted
// only the first 5 times its expansion site is reached (LOG_FIRST_N); stream
// the rest of the message after the macro.
#define LOG_FIRST_FEW_WARNING_WITH_PREFIX \
LOG_FIRST_N(WARNING, 5) << "TF-TRT Warning: "
// Allocates device memory for an execution context to execute a TensorRT
// engine and records the relevant information for deallocating the memory when
// the engine finishes execution. The allocation is released through the same
// allocator that produced it, either by the destructor or by a subsequent call
// to AllocateDeviceMemory.
//
// The class owns a raw device allocation, so it is neither copyable nor
// movable: a copy would free the same pointer twice.
class ContextDeviceMemory {
 public:
  ContextDeviceMemory() = default;
  ContextDeviceMemory(const ContextDeviceMemory&) = delete;
  ContextDeviceMemory& operator=(const ContextDeviceMemory&) = delete;

  ~ContextDeviceMemory() { FreeDeviceMemory(); }

  // Allocates `device_memory_size` bytes with `device_memory_allocator` (no
  // allocation when the size is 0) and installs the buffer on
  // `execution_context` via setDeviceMemory().
  //
  // Returns InvalidArgument if the allocator returns nullptr.
  Status AllocateDeviceMemory(nvinfer1::IExecutionContext* execution_context,
                              TRTBaseAllocator* device_memory_allocator,
                              size_t device_memory_size) {
    // A previous allocation would otherwise be leaked when this object is
    // reused.
    FreeDeviceMemory();
    execution_context_ = execution_context;
    device_memory_allocator_ = device_memory_allocator;
    VLOG(2) << "Device memory size for TensorRT engine " << device_memory_size;
    if (device_memory_size > 0) {
      device_memory_ = device_memory_allocator_->allocate(
          device_memory_size,
          /*unused alignment=*/0, /*flags=*/0);
      if (device_memory_ == nullptr) {
        return errors::InvalidArgument(
            "Out of GPU memory for execution context");
      }
    }
    {
      tensorflow::profiler::TraceMe activity(
          "setDeviceMemory", tensorflow::profiler::TraceMeLevel::kInfo);
      execution_context_->setDeviceMemory(device_memory_);
    }
    return OkStatus();
  }

 private:
  // Returns the currently held buffer (if any) to its allocator.
  void FreeDeviceMemory() {
    if (device_memory_ != nullptr) {
      device_memory_allocator_->free(device_memory_);
      device_memory_ = nullptr;
    }
  }

  nvinfer1::IExecutionContext* execution_context_ = nullptr;
  TRTBaseAllocator* device_memory_allocator_ = nullptr;
  void* device_memory_ = nullptr;
};
// Callable that does nothing. The *_ASYNC error macros used in this file
// (e.g. OP_REQUIRES_OK_ASYNC) take a callback argument to invoke on failure.
// TRTEngineOp reports completion through AsyncHelper instead, so the macros
// are handed this no-op.
struct DummyAsyncHelper {
  auto operator()() -> void {
    // Intentionally a no-op.
  }
};
// Reference-counted holder of the TRTEngineOp's DoneCallback. The callback is
// invoked when the object is destroyed, i.e. when the last reference is
// dropped. This lets the native segment and the TRT engine paths each hold a
// reference and complete asynchronously.
class AsyncHelper : public core::RefCounted {
 public:
  explicit AsyncHelper(AsyncOpKernel::DoneCallback done)
      : done_(std::move(done)) {}

  ~AsyncHelper() override {
    // Guard against an empty callback, which would otherwise throw from a
    // destructor.
    if (done_) {
      done_();
    }
  }

 private:
  AsyncOpKernel::DoneCallback done_;
};
} // end anonymous namespace
// This OP can construct TRTEngine on the fly and if construction of engine
// fails, executes equivalent subgraph as a TensorFlow function.
class TRTEngineOp : public AsyncOpKernel {
 public:
  explicit TRTEngineOp(OpKernelConstruction* context);

  void ComputeAsync(OpKernelContext* context,
                    AsyncOpKernel::DoneCallback done) override;

 private:
  // Executes calibration asynchronously.
  void ExecuteCalibration(OpKernelContext* ctx,
                          TRTEngineCacheResource* cache_res,
                          AsyncHelper* async_helper);

  // Constructs a function handle for the segment of the TRTEngineOp.
  StatusOr<FunctionLibraryRuntime::Handle> ConstructFunctionHandle(
      FunctionLibraryRuntime* lib, const string& device_name,
      bool allow_soft_placement = false, size_t num_inputs = 0,
      size_t num_outputs = 0);

  // Imports the GraphDef for the segment of the TRTEngineOp to
  // segment_graph_def_.
  Status ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
                               const string& device_name);

  // Executes the native segment as function Op asynchronously.
  void ExecuteNativeSegment(OpKernelContext* ctx, AsyncHelper* async_helper);

  // Allocates the device memory for the execution context and enqueues the
  // TensorRT engine for execution. Also deallocates the device memory.
  // NOTE(review): the result is a Status; how callers decide to retry with
  // the native segment is not visible here — confirm in the definition.
  Status ExecuteTrtEngine(OpKernelContext* ctx, EngineContext* engine_context,
                          int trt_context_idx,
                          const TrtShapeOptimizationProfile& profiles,
                          TRTBaseAllocator* allocator);

  // Allocates necessary resources for calibration.
  Status AllocateCalibrationResources(OpKernelContext* ctx,
                                      TRTEngineCacheResource* cache_res);

  Status GetEngineCacheResource(OpKernelContext* ctx,
                                TRTEngineCacheResource** cache_res);

  // Returns a pair of 1) An EngineContext object that is compatible with the
  // input and 2) The index of the IExecutionContext compatible with the input.
  // If a cuda engine for the given input shapes can't be found, returns
  // (nullptr, 0) to allow native engine execution. Returns an error code for
  // any problem that would prevent both TensorRT engine execution and native
  // segment execution.
  StatusOr<std::pair<EngineContext*, int>> GetEngine(
      const std::vector<TensorShape>& input_concrete_shapes,
      OpKernelContext* ctx, TRTEngineCacheResource* cache_resource);

  // Builds and returns a cuda engine for the input shapes. If building the
  // engine fails, enters a dummy entry into the cache_resource cache so we
  // don't continually try to build the same failing engine.
  StatusOr<TrtUniquePtrType<nvinfer1::ICudaEngine>> BuildEngine(
      const std::vector<TensorShape>& input_concrete_shapes, int batch_size,
      bool use_calibration, TRTInt8Calibrator* calibrator,
      TRTEngineCacheResource* cache_resource, OpKernelContext* ctx);

  // Verify that the input shapes are consistent and can be handled by this op.
  Status VerifyInputShapes(const std::vector<TensorShape>& shapes);

  // Default member initializers below mirror the fallbacks the constructor
  // applies when the corresponding optional attribute is absent, so no member
  // is ever read uninitialized.

  std::vector<string> input_nodes_;
  std::vector<string> output_nodes_;

  // serialized protobuf segment or trt engine depending on static_engine_ flag.
  string serialized_segment_;

  // The function for TF native execution of the segment.
  NameAttrList func_;

  // GraphDef representation of the segment.
  GraphDef segment_graph_def_;

  // Engine Precision mode.
  TrtPrecisionMode precision_mode_;

  // Whether engine is constructed during the conversion or needs to be
  // constructed from protobuf segment.
  bool static_engine_ = false;

  // Whether to calibrate INT8 engine.
  bool calibration_mode_ = false;

  // Whether to use implicit batch dimension for TensorRT.
  bool use_implicit_batch_ = true;

  // Whether to collect optimization profiles for TensorRT, only used when
  // use_implicit_batch_=false.
  bool profile_generation_mode_ = false;

  // Optimization profile generation strategy.
  ProfileStrategy profile_strategy_ = ProfileStrategy::kRange;

  // Whether the TRTEngineOp has any input with unknown dimensions.
  bool has_dynamic_shape_input_ = false;

  // Whether to build TensorRT engines at runtime.
  bool allow_build_at_runtime_ = true;

  // Whether to allow soft placement when the graph is executed with native
  // TensorFlow.
  bool allow_soft_placement_ = true;

  // Maximum number of cached engines.
  int max_cached_engines_;

  // Flag to detect whether native segment nodes have been deleted from graph
  bool native_segment_absent_ = false;

  int64 workspace_size_ = 0;
  mutex engine_mutex_;
  FunctionLibraryRuntime::Handle native_execution_func_handle_ =
      kInvalidHandle;

  // The finalized calibrator for inference.
  std::unique_ptr<TRTInt8Calibrator> calibrator_;

  // If true, create calibration graph for INT8 mode. Otherwise, we are using
  // user-provided quantization ranges.
  bool use_calibration_ = false;

  tensorflow::grappler::Cluster* cluster_ = nullptr;

  // Array of all input shapes, collected from the input_shapes attribute when
  // constructing the TRTEngineOp. The input_shapes attribute is set during
  // graph conversion time. This data is used to retrieve which input dimensions
  // could be unknown. During inference time this information is not available
  // otherwise (all shapes are known (concrete) shapes when we run inference).
  std::vector<PartialTensorShape> input_partial_shapes_;

  // Shapes, excluding resource inputs.
  std::vector<PartialTensorShape> input_partial_shapes_filtered_;

  // The TF node can have more inputs than the TRT engine: resource inputs are
  // saved as weight in the engine, instead of passing that as engine input.
  // Input mask is true for those TF input that are TRT engine inputs.
  std::vector<bool> input_mask_;

  // Whether to use explicit precision (QDQ) mode.
  bool use_explicit_precision_ = false;
};
// Expands to a `case dt:` that returns the data pointer of tensor X viewed as
// EnumToDataType<dt>::Type. flat<T>() on a const Tensor yields a const pointer;
// the const is removed explicitly because callers receive a mutable void*.
#define TYPECASE(dt, X)                                \
  case dt: {                                           \
    return const_cast<void*>(static_cast<const void*>( \
        X->flat<EnumToDataType<dt>::Type>().data()));  \
  }

// Returns the address of the tensor's data buffer for the dtypes TensorRT can
// consume. DT_BOOL requires TensorRT >= 8.2 and DT_UINT8 requires TensorRT
// >= 8.5. Logs an error and returns nullptr for any other dtype.
void* GetTensorAddress(const Tensor* tensor_ptr) {
  const auto tensor_type = tensor_ptr->dtype();
  switch (tensor_type) {
    TYPECASE(DT_FLOAT, tensor_ptr);
    TYPECASE(DT_HALF, tensor_ptr);
    TYPECASE(DT_INT8, tensor_ptr);
    TYPECASE(DT_INT32, tensor_ptr);
#if IS_TRT_VERSION_GE(8, 2, 0, 0)
    TYPECASE(DT_BOOL, tensor_ptr);
#endif
#if IS_TRT_VERSION_GE(8, 5, 0, 0)
    TYPECASE(DT_UINT8, tensor_ptr);
#endif
    default: {
      LOG(ERROR) << "Unsupported Data type " << DataTypeString(tensor_type);
      return nullptr;
    }
  }
}
// Converts the instantiated function behind `handle` into `graph_def`.
//
// Function instantiation lowercases node names, so the input/output
// placeholder prefixes (IONamePrefixes) are restored to their original case in
// both node names and node inputs. For output placeholders, the "_RetVal"
// suffix added by instantiation is stripped for backward compatibility.
//
// Returns Internal if the runtime has no function body for `handle`.
static Status FunctionDefToGraphDef(FunctionLibraryRuntime::Handle handle,
                                    FunctionLibraryRuntime* flib_runtime,
                                    GraphDef* graph_def) {
  const FunctionLibraryDefinition* lib_def =
      flib_runtime->GetFunctionLibraryDefinition();
  const FunctionBody* body = flib_runtime->GetFunctionBody(handle);
  if (body == nullptr) {
    return errors::Internal(
        "Function body is null when converting from FuncDef to GraphDef.");
  }

  auto graph = std::make_unique<Graph>(lib_def);
  CopyGraph(*body->graph, graph.get());

  // If *name starts with the lowercased `prefix`, rewrites that leading part
  // back to `prefix` and reports true.
  auto restore_prefix = [](const char* const prefix, string* name) -> bool {
    if (!absl::StartsWith(*name, absl::AsciiStrToLower(prefix))) {
      return false;
    }
    name->replace(0, strlen(prefix), prefix);
    return true;
  };

  graph->ToGraphDef(graph_def);

  for (auto& node : *graph_def->mutable_node()) {
    string* node_name = node.mutable_name();
    const bool is_input_placeholder =
        restore_prefix(IONamePrefixes::kInputPHName, node_name);
    if (!is_input_placeholder &&
        restore_prefix(IONamePrefixes::kOutputPHName, node_name)) {
      constexpr absl::string_view kRetValSuffix = "_RetVal";
      if (absl::EndsWith(*node_name, kRetValSuffix)) {
        node_name->resize(node_name->size() - kRetValSuffix.size());
      }
    }
    for (auto& input_name : *node.mutable_input()) {
      if (restore_prefix(IONamePrefixes::kInputPHName, &input_name)) {
        continue;
      }
      restore_prefix(IONamePrefixes::kOutputPHName, &input_name);
    }
  }
  return OkStatus();
}
// Instantiates the segment function func_ on `device_name` and returns its
// handle.
//
// When soft placement is requested and a native segment exists, the function
// is instantiated as a multi-device function with every input/output placed on
// `device_name` and allow_soft_placement enabled. The exception is a
// FunctionDef carrying kIntsOnDeviceAttr=true, which is incompatible with
// multi-device instantiation: it keeps the default options and a warning is
// logged.
//
// Errors: Internal if `lib` is null or the FunctionDef cannot be found;
// otherwise any error from FunctionLibraryRuntime::Instantiate.
StatusOr<FunctionLibraryRuntime::Handle> TRTEngineOp::ConstructFunctionHandle(
    FunctionLibraryRuntime* lib, const string& device_name,
    bool allow_soft_placement, size_t num_inputs, size_t num_outputs) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ConstructFunctionHandle",
      tensorflow::profiler::TraceMeLevel::kInfo);
  VLOG(1) << "Constructing function handle";
  if (!lib) {
    return errors::Internal("Context function library is null");
  }

  FunctionLibraryRuntime::InstantiateOptions options;
  options.state_handle = "";
  options.target = device_name;

  const bool wants_soft_placement =
      allow_soft_placement && !native_segment_absent_;
  if (wants_soft_placement) {
    const FunctionDef* func_def =
        lib->GetFunctionLibraryDefinition()->Find(func_.name());
    if (func_def == nullptr) {
      return errors::Internal(
          StrCat("Can't find FunctionDef for ", func_.name()));
    }
    const auto& func_attrs = func_def->attr();
    const auto ints_attr =
        func_attrs.find(FunctionLibraryDefinition::kIntsOnDeviceAttr);
    const bool ints_on_device =
        ints_attr != func_attrs.end() && ints_attr->second.b();
    if (!ints_on_device) {
      options.is_multi_device_function = true;
      options.input_devices.resize(num_inputs, device_name);
      options.output_devices.resize(num_outputs, device_name);
      options.config_proto.set_allow_soft_placement(true);
    } else {
      // kIntsOnDeviceAttr cannot be combined with is_multi_device_function,
      // which soft placement needs.
      LOG_FIRST_FEW_WARNING_WITH_PREFIX
          << "Function " << name()
          << " has attribute kIntsOnDeviceAttr=true "
             "and will be executed natively with allow_soft_placement=false. "
             "If this is a problem, please re-generate your SavedModel with "
             "the TF-TRT runtime you are using.";
    }
  }

  FunctionLibraryRuntime::Handle handle;
  const Status instantiate_status = lib->Instantiate(
      func_.name(), AttrSlice(&func_.attr()), options, &handle);
  if (!instantiate_status.ok()) {
    return instantiate_status;
  }
  return handle;
}
// Instantiates the segment function with default options (no soft placement)
// and converts its body into segment_graph_def_.
Status TRTEngineOp::ImportSegmentGraphDef(FunctionLibraryRuntime* lib,
                                          const string& device_name) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ImportSegmentGraphDef",
      tensorflow::profiler::TraceMeLevel::kInfo);
  StatusOr<FunctionLibraryRuntime::Handle> handle_or =
      ConstructFunctionHandle(lib, device_name);
  if (!handle_or.ok()) {
    return handle_or.status();
  }
  return FunctionDefToGraphDef(*handle_or, lib, &segment_graph_def_);
}
// Reads the TRTEngineOp attributes and prepares the op for execution.
//
// Optional attributes fall back to defaults when absent (NOT_FOUND):
// _allow_build_at_runtime=true, _allow_soft_placement=true,
// _use_implicit_batch=true, _profile_generation_mode=false and
// profile_strategy=Range. Any other attribute error fails construction.
// use_explicit_precision falls back to false on any error.
//
// For a non-static engine with a native segment, the segment GraphDef is
// imported eagerly.
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::TRTEngineOp", tensorflow::profiler::TraceMeLevel::kInfo);

  // read serialized_engine
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_segment", &serialized_segment_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("workspace_size_bytes", &workspace_size_));
  OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));

  VLOG(1) << "Constructing " << name();
  string precision_string;
  OP_REQUIRES_OK(context,
                 context->GetAttr("precision_mode", &precision_string));
  string calibration_data;
  OP_REQUIRES_OK(context,
                 context->GetAttr("calibration_data", &calibration_data));
  OP_REQUIRES_OK(context, context->GetAttr("segment_func", &func_));
  OP_REQUIRES_OK(context,
                 TrtPrecisionModeFromName(precision_string, &precision_mode_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("use_calibration", &use_calibration_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("input_shapes", &input_partial_shapes_));

  auto status =
      context->GetAttr("_allow_build_at_runtime", &allow_build_at_runtime_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    VLOG(2) << "Not found _allow_build_at_runtime in "
            << context->device()->name()
            << ", thus setting _allow_build_at_runtime=true";
    allow_build_at_runtime_ = true;
  } else {
    OP_REQUIRES_OK(context, status);
  }

  // Get a mask of non-resource inputs.
  std::vector<DataType> in_types;
  OP_REQUIRES_OK(context, context->GetAttr("InT", &in_types));
  // The mask is indexed by input_shapes but read from InT; refuse inputs that
  // would make that read run past the end of in_types.
  OP_REQUIRES(context, input_partial_shapes_.size() <= in_types.size(),
              errors::InvalidArgument(
                  "Attribute input_shapes has ", input_partial_shapes_.size(),
                  " entries but attribute InT only has ", in_types.size()));
  input_mask_.resize(input_partial_shapes_.size());
  for (size_t i = 0; i < input_mask_.size(); ++i) {
    input_mask_[i] = (in_types[i] != DataType::DT_RESOURCE);
  }
  // Filter the shapes to exclude resources.
  for (size_t i = 0; i < input_partial_shapes_.size(); ++i) {
    if (input_mask_[i]) {
      input_partial_shapes_filtered_.push_back(input_partial_shapes_[i]);
    }
  }

  status = context->GetAttr("_allow_soft_placement", &allow_soft_placement_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    allow_soft_placement_ = true;
  } else {
    OP_REQUIRES_OK(context, status);
  }

  status = context->GetAttr("use_explicit_precision", &use_explicit_precision_);
  if (!status.ok()) {
    use_explicit_precision_ = false;
  }

  // When a TF-TRT converted model without native segments is loaded,
  // func_ can be empty.
  native_segment_absent_ = func_.name().empty();
  native_execution_func_handle_ = kInvalidHandle;
  if (!native_segment_absent_ && !static_engine_) {
    OP_REQUIRES_OK(context, ImportSegmentGraphDef(context->function_library(),
                                                  context->device()->name()));
  }

  // TODO(laigd): calibration_data is used in TF v1.x and we keep it only for
  // backward compatibility reasons. Remove it once all known users switch to
  // 2.0.
  calibration_mode_ =
      (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
       calibration_data.empty());
  if (!calibration_data.empty()) {
    calibrator_ = std::make_unique<TRTInt8Calibrator>(calibration_data);
    calibration_data.resize(0);
  }
  OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                           &max_cached_engines_));

  status = context->GetAttr("_use_implicit_batch", &use_implicit_batch_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    VLOG(2) << "Not found _use_implicit_batch in " << context->device()->name()
            << ", thus setting _use_implicit_batch=true";
    use_implicit_batch_ = true;
  } else {
    // Other errors used to be ignored, leaving the flag uninitialized.
    OP_REQUIRES_OK(context, status);
  }

  status =
      context->GetAttr("_profile_generation_mode", &profile_generation_mode_);
  if (status.code() == tensorflow::error::NOT_FOUND) {
    VLOG(2) << "Not found _profile_generation_mode in "
            << context->device()->name()
            << ", thus setting _profile_generation_mode=false";
    profile_generation_mode_ = false;
  } else {
    // Other errors used to be ignored, leaving the flag uninitialized.
    OP_REQUIRES_OK(context, status);
  }
  // A static engine is already built; there are no profiles to collect.
  if (static_engine_) {
    profile_generation_mode_ = false;
  }

  if (use_implicit_batch_) {
    OP_REQUIRES(context, !profile_generation_mode_,
                errors::InvalidArgument(
                    "profile_generation_mode_=true is only supported if "
                    "use_implicit_batch=false"));
    if (input_partial_shapes_.empty()) {
      VLOG(1) << "Attribute input_shapes is not set. This happens probably "
              << "because you are using a model that is already converted "
              << "to TensorRT with a previous version of TF-TRT (i.e. includes "
              << "TRTEngineOp in graph). This is not an error. If you convert "
              << "the original model again to TensorRT, the attributes "
              << "input_shapes will be set automatically.";
    }
  } else {
    OP_REQUIRES(
        context, !input_partial_shapes_.empty(),
        errors::InvalidArgument(
            "Explicit batch mode requires attribute input_shapes to be set. "
            "If you are using a model that was converted to TensorRT by a "
            "previous version of TF-TRT, (i.e. includes TRTEngineOp in graph "
            "without the input_shapes attribute), then you need to convert the "
            "original model again to TensorRT in order to set the attribute "
            "input_shapes."));
    string profile_strategy_name;
    status = context->GetAttr("profile_strategy", &profile_strategy_name);
    if (status.code() == tensorflow::error::NOT_FOUND) {
      VLOG(2) << "Not found strategy in " << context->device()->name()
              << ", thus setting profile_strategy='Range'";
      profile_strategy_ = ProfileStrategy::kRange;
    } else {
      OP_REQUIRES_OK(context, ProfileStrategyFromName(profile_strategy_name,
                                                      &profile_strategy_));
    }
  }

  has_dynamic_shape_input_ = absl::c_any_of(
      input_partial_shapes_filtered_,
      [](const PartialTensorShape& shape) { return !shape.IsFullyDefined(); });
  VLOG(2) << "TRTEngineOp has_dynamic_shape_input_: "
          << has_dynamic_shape_input_;
}
// Copies input tensor ctx->input(i) (which is in device memory) to the host,
// and place the resulting host tensor to the back of native_inputs.
// Allocates a host tensor shaped like ctx->input(i), appends it to
// native_inputs and enqueues an async device-to-host copy of the input into it
// on `stream`. The caller must synchronize the stream before reading the data.
// The copy treats the data as int32, because the native segment expects int32
// inputs in host memory while the TRTEngineOp receives everything on the
// device.
Status CopyToHostAsync(OpKernelContext* ctx, std::vector<Tensor>* native_inputs,
                       int i, const cudaStream_t stream) {
  const Tensor& device_tensor = ctx->input(i);

  AllocatorAttributes host_attr;
  host_attr.set_on_host(true);
  Tensor host_tensor;
  TF_RETURN_IF_ERROR(ctx->allocate_temp(
      ctx->input_dtype(i), device_tensor.shape(), &host_tensor, host_attr));
  native_inputs->push_back(host_tensor);

  const cudaError_t copy_result = cudaMemcpyAsync(
      host_tensor.flat<int32>().data(), device_tensor.flat<int32>().data(),
      host_tensor.NumElements() * sizeof(int32), cudaMemcpyDeviceToHost,
      stream);
  if (copy_result == cudaSuccess) {
    return OkStatus();
  }
  return errors::Internal("Could not copy tensor for native segment input");
}
// Copies native_tensor, which is in host memory to ctx->output(t), which is in
// device memory.
// Allocates output `t` on the device with native_tensor's shape and enqueues an
// async host-to-device copy of native_tensor (read as int32) into it on
// `stream`. The caller must synchronize the stream before native_tensor's host
// buffer is released.
Status CopyToDeviceAsync(OpKernelContext* ctx, const Tensor& native_tensor,
                         int t, cudaStream_t stream) {
  Tensor* device_output = nullptr;
  TF_RETURN_IF_ERROR(
      ctx->allocate_output(t, native_tensor.shape(), &device_output));

  const size_t num_bytes = native_tensor.NumElements() * sizeof(int32);
  const cudaError_t copy_result =
      cudaMemcpyAsync(device_output->flat<int32>().data(),
                      native_tensor.flat<int32>().data(), num_bytes,
                      cudaMemcpyHostToDevice, stream);
  if (copy_result == cudaSuccess) {
    return OkStatus();
  }
  return errors::Internal("Could not copy tensor for native segment output");
}
// Runs the segment as a TensorFlow function (instantiating the handle lazily).
// int32 inputs are staged to host memory and int32 outputs are copied back to
// the device; every other tensor is passed through as-is. async_helper is
// Ref()'d for the duration of the asynchronous function call and Unref()'d in
// its completion callback. Errors are reported on `ctx`.
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                       AsyncHelper* async_helper) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ExecuteNativeSegment",
      tensorflow::profiler::TraceMeLevel::kInfo);
  std::vector<Tensor> native_inputs;
  // Owned here until handed to the completion callback, so the early error
  // returns below no longer leak it.
  auto native_outputs = std::make_unique<std::vector<Tensor>>();
  DummyAsyncHelper dummy_async_helper;
  if (native_execution_func_handle_ == kInvalidHandle) {
    StatusOr<FunctionLibraryRuntime::Handle> status_or_handle =
        ConstructFunctionHandle(ctx->function_library(), ctx->device()->name(),
                                allow_soft_placement_, ctx->num_inputs(),
                                ctx->num_outputs());
    OP_REQUIRES_OK_ASYNC(ctx, status_or_handle.status(), dummy_async_helper);
    native_execution_func_handle_ = *status_or_handle;
  }

  auto lib = ctx->function_library();
  FunctionLibraryRuntime::Options opts;
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();

  native_inputs.reserve(ctx->num_inputs());
  int n_copies = 0;
  cudaStream_t stream = reinterpret_cast<cudaStream_t>(CHECK_NOTNULL(
      ctx->op_device_context()->stream()->platform_specific_handle().stream));
  for (int i = 0; i < ctx->num_inputs(); i++) {
    if (ctx->input_dtype(i) != DT_INT32) {
      native_inputs.push_back(ctx->input(i));
      continue;
    }
    const Status copy_status = CopyToHostAsync(ctx, &native_inputs, i, stream);
    if (!copy_status.ok() && n_copies > 0) {
      // Earlier copies may still be writing into host tensors owned by
      // native_inputs; drain them before those buffers are released.
      cudaStreamSynchronize(stream);
    }
    OP_REQUIRES_OK_ASYNC(ctx, copy_status, dummy_async_helper);
    n_copies++;
  }
  if (n_copies > 0) {
    // If we have any int32 tensors, then wait until data is copied to host.
    cudaStreamSynchronize(stream);
  }

  VLOG(1) << "Executing native segment: " << name();
  // Increment the reference count of the async_helper by 1. When the native
  // segment finishes execution asynchronously, we decrement the reference
  // count of the object.
  async_helper->Ref();
  // Ownership of the output vector moves to the completion callback.
  std::vector<Tensor>* outputs = native_outputs.release();
  lib->Run(
      opts, native_execution_func_handle_, native_inputs, outputs,
      [this, ctx, outputs, async_helper, stream](const Status& s) {
        core::ScopedUnref sc(async_helper);
        DummyAsyncHelper dummy_async_helper;
        std::unique_ptr<std::vector<Tensor>> outputs_wrapper(outputs);
        OP_REQUIRES_OK_ASYNC(ctx, s, dummy_async_helper);
        VLOG(1) << "Native Segment completed";
        int n_copies = 0;
        for (size_t t = 0; t < outputs->size(); ++t) {
          const Tensor& native_output = outputs->at(t);
          if (native_output.dtype() != DT_INT32) {
            ctx->set_output(t, native_output);
            continue;
          }
          const Status copy_status =
              CopyToDeviceAsync(ctx, native_output, t, stream);
          if (!copy_status.ok() && n_copies > 0) {
            // Earlier copies may still be reading from host tensors that
            // outputs_wrapper is about to free.
            cudaStreamSynchronize(stream);
          }
          OP_REQUIRES_OK_ASYNC(ctx, copy_status, dummy_async_helper);
          n_copies++;
        }
        if (n_copies > 0) {
          cudaStreamSynchronize(stream);
        }
      });
}
// Feeds the current inputs to the INT8 calibrator, then produces the op's
// outputs by running the native segment. Calibration data is only fed when
// every input's byte size matches the preallocated calibration device tensor.
// When the native segment is missing, no outputs can be produced and the op
// fails.
void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
                                     TRTEngineCacheResource* cache_res,
                                     AsyncHelper* async_helper) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ExecuteCalibration",
      tensorflow::profiler::TraceMeLevel::kInfo);
  VLOG(1) << "Executing TRT calibration: " << name();
  DummyAsyncHelper dummy_async_helper;

  CalibrationContext* calib_ctx = cache_res->calib_ctx_.get();
  const int num_inputs = ctx->num_inputs();
  // TODO(laigd): need to check that input shape matches.
  // Pass input data to calibrator
  std::unordered_map<string, void*> input_data;
  bool input_size_ok = true;
  for (int i = 0; i < num_inputs; i++) {
    const Tensor& t = ctx->input(i);
    void* data_address = GetTensorAddress(&t);
    OP_REQUIRES_ASYNC(ctx, data_address,
                      errors::InvalidArgument(
                          "Unsupported data type encountered in input ", i),
                      dummy_async_helper);
    // Check the allocated buffer is sufficient for input
    const auto& device_tensor = calib_ctx->device_tensors_.at(i);
    if (t.TotalBytes() != device_tensor.TotalBytes()) {
      // This can happen if the network has data dependent shapes.
      input_size_ok = false;
      VLOG(2) << "Size differs for input " << i
              << ", skipping calibration for this input.";
      break;
    }
    input_data.emplace(StrCat(IONamePrefixes::kInputPHName, i), data_address);
  }

  if (input_size_ok) {
    VLOG(2) << "Filled map for sending";
    // Copied from gpu_kernel_helper.h as the header can only be used in *.cu.cc
    // files.
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(CHECK_NOTNULL(
        ctx->op_device_context()->stream()->platform_specific_handle().stream));
    // TRTInt8Calibrator::setBatch will wait until TRTInt8Calibrator::getBatch
    // is called before proceeding with feeding the calibration data to the
    // calibrator. It returns true if the calibration data is accepted and
    // returns false if calibration is terminated due to errors.
    //
    // If TRTInt8Calibrator::getBatch is never called, which could happen if
    // there is any problem in building the cuda engine for calibration inside
    // TensorRT, then the TRTInt8Calibrator::setBatch call here will hang until
    // TRTInt8Calibrator::setDone is called by the calibration thread in
    // AllocateCalibrationResources.
    //
    // In both of the above cases, setBatch here returns a boolean value to
    // indicate the result of the calibration process.
    if (!calib_ctx->calibrator_->setBatch(input_data, stream)) {
      VLOG(2) << "Failed to feed calibration data";
    } else {
      VLOG(2) << "Passed calibration data";
    }
  }

  if (native_segment_absent_) {
    LOG(ERROR) << "Calibration requires native segment, but is not found in "
                  "the graph.";
    // Without the native segment no outputs are produced; fail the op instead
    // of letting it appear to succeed with unset outputs.
    ctx->SetStatus(errors::Internal(
        "Calibration requires native segment, but is not found in the "
        "graph."));
    return;
  }
  ExecuteNativeSegment(ctx, async_helper);
}
// Checks that the concrete input shapes can be handled by this op.
//
// When input_partial_shapes_filtered_ is set, the concrete shapes must match
// it in count, rank, and every known (non -1) dimension. In implicit batch
// mode every input must have at least one dimension, and all inputs must share
// a leading batch size of at least 1. Returns InvalidArgument on violation.
Status TRTEngineOp::VerifyInputShapes(
    const std::vector<TensorShape>& input_concrete_shapes) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::VerifyInputShapes",
      tensorflow::profiler::TraceMeLevel::kInfo);
  if (input_concrete_shapes.empty()) {
    return errors::InvalidArgument("Input shapes are empty, for ", name());
  }

  if (input_partial_shapes_filtered_.empty()) {
    if (!use_implicit_batch_) {
      return errors::InvalidArgument(
          "Explicit batch mode requires input_partial_shapes_ ",
          "to contain the dynamic input shapes to TRTEngineOp");
    }
    // If the graph was converted with an earlier version of TF-TRT, it can
    // happen that the input_partial_shapes_ vector is not set (see
    // input_shapes attribute handling in the TRTEngineOp constructor).
    // In implicit batch mode it is allowed to have empty input_partial_shapes_,
    // since it is only required in explicit batch mode (see the input_shapes
    // attribute of ConvertGraphDefToEngine in TRTEngineOp::GetEngine.
  } else {
    // Additional consistency checks if input_partial_shapes_ is present.
    const string error_msg = StrCat(
        "Input shapes do not match input partial shapes stored in graph, for ",
        name(), ": ", DebugString(input_concrete_shapes),
        " != ", DebugString(input_partial_shapes_filtered_));
    if (input_concrete_shapes.size() != input_partial_shapes_filtered_.size()) {
      return errors::InvalidArgument(error_msg);
    }
    for (size_t i = 0; i < input_concrete_shapes.size(); ++i) {
      const TensorShape& concrete = input_concrete_shapes[i];
      const PartialTensorShape& partial = input_partial_shapes_filtered_[i];
      if (concrete.dims() != partial.dims()) {
        return errors::InvalidArgument(error_msg);
      }
      for (int d = 0; d < concrete.dims(); ++d) {
        // -1 marks an unknown dimension, which matches any concrete size.
        if (partial.dim_size(d) != -1 &&
            concrete.dim_size(d) != partial.dim_size(d)) {
          return errors::InvalidArgument(error_msg);
        }
      }
    }
  }

  if (use_implicit_batch_) {
    if (input_concrete_shapes[0].dims() < 1) {
      return errors::InvalidArgument(
          "Input shapes contain scalar, for ", name(), ": ",
          TensorShapeUtils::ShapeListString(input_concrete_shapes));
    }
    const int64 batch_size = input_concrete_shapes[0].dim_size(0);
    if (batch_size < 1) {
      return errors::InvalidArgument(
          "Incorrect batch dimension, for ", name(), ": ",
          TensorShapeUtils::ShapeListString(input_concrete_shapes));
    }
    for (const TensorShape& shape : input_concrete_shapes) {
      // Every input needs a batch dimension; otherwise dim_size(0) below is
      // out of range.
      if (shape.dims() < 1) {
        return errors::InvalidArgument(
            "Input shapes contain scalar, for ", name(), ": ",
            TensorShapeUtils::ShapeListString(input_concrete_shapes));
      }
      if (batch_size != shape.dim_size(0)) {
        return errors::InvalidArgument(
            "Input shapes are inconsistent on the batch dimension, for ",
            name(), ": ",
            TensorShapeUtils::ShapeListString(input_concrete_shapes));
      }
    }
  }
  return OkStatus();
}
// Returns whether TRTEngineOp may fall back to running the native TF segment
// after TRT engine retrieval or execution fails.
//
// Controlled by the environment variable
// TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION (default: true). If reading
// the variable fails, the error is logged and `value` is still returned.
static bool AllowEngineNativeSegmentExecution() {
  // Start from the documented default so that a failed read can never return
  // an indeterminate value. NOTE(review): assumes ReadBoolFromEnvVar either
  // leaves `value` untouched or sets it to default_val on error — confirm.
  bool value = true;
  Status status =
      ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION",
                         /*default_val=*/true, &value);
  if (!status.ok()) {
    LOG(ERROR) << status;
  }
  return value;
}
// Asynchronous entry point of TRTEngineOp.
//
// Outline of the flow implemented below:
//  1. Look up the TRTEngineCacheResource for this op.
//  2. Verify the runtime input shapes; on failure, fall back to the native
//     segment when one is present.
//  3. In explicit batch mode with dynamic-shape inputs or shape tensors,
//     collect shape values. Then either record the shapes for profile
//     generation (and run the native segment), or initialize the
//     optimization profiles.
//  4. In calibration mode with an empty engine cache, set up the calibration
//     resources once and run calibration.
//  5. Otherwise fetch an engine for the input shapes and execute it. If
//     retrieval or execution fails, fall back to the native segment when
//     that is permitted (see may_execute_native_segment).
//
// `done` is invoked when the AsyncHelper below is destroyed, which may happen
// after this function returns (e.g. while the native segment is running).
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                               AsyncOpKernel::DoneCallback done) {
  tensorflow::profiler::TraceMe activity(
      "TRTEngineOp::ComputeAsync", tensorflow::profiler::TraceMeLevel::kInfo);
  // Invoke DoneCallback when this object is destructed, which could be after
  // this routine finishes execution, in particular, when native segment is
  // executed.
  auto async_helper = new AsyncHelper(done);
  core::ScopedUnref sc(async_helper);
  // For all async execution macros, use this object as there is no need to call
  // DoneCallback from those macros.
  DummyAsyncHelper dummy_async_helper;

  // Get TRT resource.
  TRTEngineCacheResource* cache_res = nullptr;
  OP_REQUIRES_OK_ASYNC(ctx, GetEngineCacheResource(ctx, &cache_res),
                       dummy_async_helper);
  core::ScopedUnref unref_cache_res(cache_res);

  // Get shapes of inputs to engine. Resource-typed inputs are excluded from
  // the filtered list that is used for shape verification.
  std::vector<TensorShape> input_concrete_shapes;
  input_concrete_shapes.reserve(ctx->num_inputs());
  std::vector<TensorShape> input_concrete_shapes_filtered;
  input_concrete_shapes_filtered.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); ++i) {
    input_concrete_shapes.push_back(ctx->input(i).shape());
    if (ctx->input(i).dtype() != DataType::DT_RESOURCE) {
      input_concrete_shapes_filtered.push_back(ctx->input(i).shape());
    }
  }

  /// TODO(lsugy): fix case of engine with only resource inputs.
  Status verify_input_shape_status =
      VerifyInputShapes(input_concrete_shapes_filtered);
  // TODO(bixia): Fix the segmentation.
  if (!verify_input_shape_status.ok() && !native_segment_absent_) {
    LOG_FIRST_FEW_WARNING_WITH_PREFIX
        << "Running native segment for " << name()
        << " due to failure in verifying input shapes: "
        << verify_input_shape_status.message();
    ExecuteNativeSegment(ctx, async_helper);
    return;
  }

  if (!use_implicit_batch_ &&
      (has_dynamic_shape_input_ || cache_res->profiles_.HasShapeTensor())) {
    OP_REQUIRES_OK_ASYNC(ctx, cache_res->profiles_.CollectShapeValues(ctx),
                         dummy_async_helper);
    cache_res->profiles_.SetInputMask(input_mask_);
    if (profile_generation_mode_) {
      // Collecting new shapes for profiles can be only done once. After the
      // shapes are converted to TRT profiles, no shapes can be collected
      // anymore.
      OP_REQUIRES_ASYNC(ctx, cache_res->profiles_.GetNumProfiles() == 0,
                        errors::Unimplemented("Cannot collect new shapes when "
                                              "profiles are already created."),
                        dummy_async_helper);
      // Just collect the input shape info and return. The shapes are used to
      // generate optimization profiles during engine creation.
      cache_res->profiles_.AddShape(input_concrete_shapes);
      VLOG(1)
          << "Native segment is used during collecting shapes for profiles.";
      if (!native_segment_absent_) {
        ExecuteNativeSegment(ctx, async_helper);
      } else {
        LOG(ERROR) << "Native segment is required for profile generation, "
                      "but is not found in the graph.";
      }
      return;
    } else if (cache_res->profiles_.GetNumProfiles() == 0 && !static_engine_) {
      // Add current shape if we did not collect any shapes so far.
      if (!cache_res->profiles_.HasShape()) {
        cache_res->profiles_.AddShape(input_concrete_shapes);
      }
      // Create profiles out of collected shapes during profile generation.
      cache_res->profiles_.InitProfiles(input_partial_shapes_,
                                        profile_strategy_);
    }
  }

  // Run calibration if in int8+calibration mode.
  // * Logic in TF 1.x:
  //   - During conversion: calibration_mode_ is true and cache size is 0, so it
  //     will run calibration.
  //   - During inference: calibration_data will be set, so calibration_mode_
  //     is false and it won't trigger calibration.
  // * Logic in TF 2.0:
  //   - During conversion: similar to 1.x.
  //   - During inference: calibration_data will still be empty, but cache will
  //     contain the calibrated engine, so it won't trigger calibration.
  //
  // TODO(laigd): consider the following alternatives:
  // 1. Serialize the state (calibration or inference) using
  //    TRTEngineInstance proto (or a new proto), so we know which mode we're
  //    in and don't run calibration during inference (which is invalid).
  // 2. Reuse the calibration_data attribute or use a new attribute in the
  //    NodeDef to indicate whether it's in calibration mode.
  if (calibration_mode_ && cache_res->cache_.size() == 0) {
    if (!cache_res->calib_ctx_) {
      // TODO(laigd): better encapsulation.
      mutex_lock lock(engine_mutex_);
      // Re-check under the lock: another call may have allocated the
      // calibration context since the unlocked check above.
      if (!cache_res->calib_ctx_) {
        // Add profiles if we are in dynamic shape mode.
        if (!use_implicit_batch_ && (has_dynamic_shape_input_ ||
                                     cache_res->profiles_.HasShapeTensor())) {
          cache_res->profiles_.InitCalibProfile(input_concrete_shapes);
        }
        OP_REQUIRES_OK_ASYNC(ctx, AllocateCalibrationResources(ctx, cache_res),
                             dummy_async_helper);
      }
    }
    // TODO(laigd): check that the input shapes match the shapes of the
    // persistent tensor in the calibration resource.
    ExecuteCalibration(ctx, cache_res, async_helper);
    return;
  }

  StatusOr<std::pair<EngineContext*, int>> status =
      GetEngine(input_concrete_shapes, ctx, cache_res);
  OP_REQUIRES_OK_ASYNC(ctx, status.status(), dummy_async_helper);
  EngineContext* engine_context = status.value().first;
  int trt_context_idx = status.value().second;

  // Returns true if falling back to the native segment is possible. Otherwise
  // records an Aborted failure on `ctx` and returns false. A true result
  // therefore implies the native segment is present.
  auto may_execute_native_segment = [&] {
    if (!native_segment_absent_ && !AllowEngineNativeSegmentExecution()) {
      ctx->CtxFailure(
          errors::Aborted("User disallowed engine native segment execution."));
      return false;
    } else if (native_segment_absent_) {
      ctx->CtxFailure(
          errors::Aborted("Native segment execution is enabled but "
                          "native segment is not found in the graph."));
      return false;
    }
    return true;
  };

  if (!engine_context->GetCudaEngine()) {
    LOG_FIRST_FEW_WARNING_WITH_PREFIX
        << "Engine retrieval for input shapes: "
        << TensorShapeUtils::ShapeListString(input_concrete_shapes)
        << " failed. Running native segment for " << name();
    if (may_execute_native_segment()) {
      ExecuteNativeSegment(ctx, async_helper);
    }
    return;
  }

  Status stat =
      ExecuteTrtEngine(ctx, engine_context, trt_context_idx,
                       cache_res->profiles_, cache_res->allocator_.get());
  if (stat.ok()) return;

  LOG_FIRST_FEW_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat
                                    << " Retrying with native segment for "
                                    << name();
  if (!may_execute_native_segment()) {
    return;
  }
  // When Native Segment execution is enabled, release any outputs that
  // are allocated. ExecuteNativeSegment will re-allocate them and
  // fail if they are currently allocated.
  // The Tensor pointer in the returned TensorValue must be explicitly
  // deleted.
  for (int i = 0; i < ctx->num_outputs(); i++) {
    delete ctx->release_output(i).tensor;
  }
  // may_execute_native_segment() returned true, so the native segment is
  // present and no further presence check is needed.
  ExecuteNativeSegment(ctx, async_helper);
}
Status TRTEngineOp::ExecuteTrtEngine(
OpKernelContext* ctx, EngineContext* engine_context, int trt_context_idx,
const TrtShapeOptimizationProfile& profiles, TRTBaseAllocator* allocator) {
tensorflow::profiler::TraceMe activity(
"TRTEngineOp::ExecuteTrtEngine",
tensorflow::profiler::TraceMeLevel::kInfo);
VLOG(1) << "Executing TRT engine: " << name();
nvinfer1::ICudaEngine* cuda_engine = engine_context->GetCudaEngine();
if (VLOG_IS_ON(2)) {
VLOG(2) << " Network name: " << cuda_engine->getName();
VLOG(2) << " Activation size: " << engine_context->GetDeviceMemorySize()
<< " bytes";