-
Notifications
You must be signed in to change notification settings - Fork 74k
/
trt_engine_op.cc
651 lines (619 loc) · 26.7 KB
/
trt_engine_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.h"

#include <algorithm>
#include <memory>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/test_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resource_manager.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_resources.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda_runtime_api.h"
namespace tensorflow {
namespace tensorrt {
// Logger instance handed to the TensorRT runtime and to the engine-conversion
// calls made from this file.
static Logger logger;
// Short names for the string helpers and the TensorRT runtime type used below.
using absl::StrAppend;
using absl::StrCat;
using ::nvinfer1::IRuntime;
// Ref-counted wrapper around an AsyncOpKernel completion callback. The wrapped
// callback is invoked from the destructor, so code paths that run the native
// segment and the TRT engine asynchronously can each hold a reference and the
// op is reported as done only once the helper is destroyed.
class AsyncHelper : public tensorflow::core::RefCounted {
 public:
  // Stores a copy of `done`; it is not called until destruction.
  AsyncHelper(AsyncOpKernel::DoneCallback done) : done_(done) {}

  // Signals completion of the asynchronous op.
  ~AsyncHelper() override { done_(); }

 private:
  AsyncOpKernel::DoneCallback done_;
};
#define TYPECASE(dt, X, Y) \
case dt: { \
return (void*)X->flat<tensorflow::EnumToDataType<dt>::Type>().data(); \
}
void* GetTensorAddress(const Tensor* tensor_ptr) {
auto tensor_type = tensor_ptr->dtype();
switch (tensor_type) {
TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr);
TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr);
TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr);
default: {
LOG(ERROR) << "Unsupported Data type "
<< tensorflow::DataTypeString(tensor_type);
return nullptr;
}
}
}
// Instantiates the native (non-TensorRT) function named `funcdef_name_` in the
// context's function library and stores the resulting handle in
// `native_func_`.
//
// Returns:
//   - Internal error if the context has no function library, or if the
//     FunctionDef cannot be found in it (`native_func_` is left untouched);
//   - otherwise the status returned by Instantiate().
//
// If Instantiate() fails, `native_func_` is reset to kInvalidHandle. This
// matters because ExecuteNativeSegment() only calls this function while
// `native_func_ == kInvalidHandle`; leaving any other value behind after a
// failure would make later calls run an uninstantiated handle.
tensorflow::Status TRTEngineOp::ConstructFunctionHandle(OpKernelContext* ctx) {
  VLOG(1) << "Constructing function handle";
  auto lib = ctx->function_library();
  if (lib == nullptr) {
    return tensorflow::errors::Internal("Context function library is null");
  }
  auto fdef = lib->GetFunctionLibraryDefinition()->Find(funcdef_name_);
  if (fdef == nullptr) {
    return tensorflow::errors::Internal("Native FunctionDef ", funcdef_name_,
                                        " can't be found in function library");
  }
  tensorflow::FunctionLibraryRuntime::InstantiateOptions inst_ops;
  inst_ops.overlay_lib = nullptr;
  inst_ops.state_handle = "";
  // Instantiate on the same device this op is running on.
  inst_ops.target = ctx->device()->name();
  native_func_ = tensorflow::kInvalidHandle;
  auto status = lib->Instantiate(funcdef_name_, AttrSlice(&fdef->attr()),
                                 inst_ops, &native_func_);
  if (!status.ok()) {
    LOG(ERROR) << " Instantiating native function " << funcdef_name_
               << " failed!";
    // Do not trust whatever Instantiate() may have written on failure.
    native_func_ = tensorflow::kInvalidHandle;
  }
  return status;
}
// Builds the kernel from the op's attributes.
//
// Reads the serialized segment, workspace size, engine mode, precision and
// calibration settings, and the engine-cache configuration. When the op is not
// a static engine, the serialized segment is parsed into `segment_graph_` and
// the string is then emptied. Any attribute or parse failure is reported
// through `context` and construction stops at that point.
TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
    : AsyncOpKernel(context) {
  OP_REQUIRES_OK(context,
                 context->GetAttr("serialized_segment", &serialized_segment_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("workspace_size_bytes", &workspace_size_));
  OP_REQUIRES_OK(context, context->GetAttr("static_engine", &static_engine_));

  if (!static_engine_) {
    // Dynamic mode: the attribute holds a graph definition rather than a
    // prebuilt engine, so decode it now.
    const bool parsed = segment_graph_.ParseFromString(serialized_segment_);
    if (!parsed) {
      LOG(ERROR) << "Parsing segment graph failed!";
      context->SetStatus(tensorflow::errors::InvalidArgument(
          "Failed to parse segment graphdef!"));
      return;
    }
    // The parsed graph is kept; the raw bytes are no longer needed.
    serialized_segment_.clear();
  }

  VLOG(1) << "Constructing " << name();

  string precision_attr;
  OP_REQUIRES_OK(context, context->GetAttr("precision_mode", &precision_attr));
  string calibration_table;
  OP_REQUIRES_OK(context,
                 context->GetAttr("calibration_data", &calibration_table));
  OP_REQUIRES_OK(context,
                 context->GetAttr("segment_funcdef_name", &funcdef_name_));
  OP_REQUIRES_OK(context, GetPrecisionMode(precision_attr, &precision_mode_));
  OP_REQUIRES_OK(context,
                 context->GetAttr("use_calibration", &use_calibration_));

  // Calibration is only run when INT8 with calibration is requested and no
  // calibration data was supplied with the op.
  const bool has_calibration_table = !calibration_table.empty();
  calibration_mode_ = use_calibration_ && precision_mode_ == INT8MODE &&
                      !has_calibration_table;
  if (has_calibration_table) {
    calibrator_.reset(new TRTInt8Calibrator(calibration_table));
    calibration_table.clear();
  }

  // The native function is instantiated lazily on first use.
  native_func_ = tensorflow::kInvalidHandle;

  OP_REQUIRES_OK(context, context->GetAttr("max_cached_engines_count",
                                           &max_cached_engines_));
  OP_REQUIRES_OK(context, context->GetAttr("cached_engine_batches",
                                           &cached_engine_batches_));
  std::sort(cached_engine_batches_.begin(), cached_engine_batches_.end());

  if (VLOG_IS_ON(1)) {
    string batches_msg("Engine Batches= ");
    for (const auto batch : cached_engine_batches_) {
      StrAppend(&batches_msg, batch, " ");
    }
    VLOG(1) << batches_msg;
  }
}
// Runs the segment through the native TensorFlow function instead of a TRT
// engine, asynchronously.
//
// ctx:    kernel context; supplies the inputs and receives the outputs or the
//         error status.
// helper: ref-counted completion helper. A reference is taken before the
//         function is launched and released when the completion callback
//         finishes.
//
// If the native function handle has not been constructed yet it is constructed
// here; on failure the error is set on `ctx` and nothing is launched.
void TRTEngineOp::ExecuteNativeSegment(OpKernelContext* ctx,
                                       AsyncHelper* helper) {
  if (native_func_ == tensorflow::kInvalidHandle) {
    auto status = ConstructFunctionHandle(ctx);
    if (!status.ok()) {
      LOG(ERROR) << "Couldn't construct function handle " << funcdef_name_;
      ctx->SetStatus(status);
      return;
    }
  }
  auto lib = ctx->function_library();
  tensorflow::FunctionLibraryRuntime::Options opts;
  opts.step_id = ctx->step_id();
  opts.rendezvous = ctx->rendezvous();
  opts.cancellation_manager = ctx->cancellation_manager();
  opts.runner = ctx->runner();
  std::vector<Tensor> inputs;
  inputs.reserve(ctx->num_inputs());
  for (int i = 0; i < ctx->num_inputs(); i++) {
    inputs.push_back(ctx->input(i));
  }
  // Allocated only once the launch is certain, so the early return above
  // cannot leak it. The callback below takes ownership. A raw pointer is
  // captured because the callback must be copyable.
  std::vector<Tensor>* outputs = new std::vector<Tensor>();
  helper->Ref();  // Increment count for calculating native graph
  VLOG(1) << "Executing native segment: " << name();
  lib->Run(opts, native_func_, inputs, outputs,
           [this, ctx, outputs, helper](const tensorflow::Status& s) {
             tensorflow::core::ScopedUnref sc(helper);
             // Frees `outputs` on every exit path, including the error path.
             // Declared after `sc` so it is destroyed before the helper
             // reference is dropped.
             std::unique_ptr<std::vector<Tensor>> outputs_owner(outputs);
             if (!s.ok()) {
               LOG(ERROR) << "Failed to execute native segment " << this->name()
                          << ": " << s;
               ctx->SetStatus(s);
               return;
             }
             VLOG(1) << "Native Segment completed";
             for (size_t t = 0; t < outputs->size(); ++t) {
               ctx->set_output(t, outputs->at(t));
             }
             test::AddTestValue(StrCat(this->name(), ":ExecuteNativeSegment"),
                                "done");
           });
}
// Runs one calibration step for this op: passes the current inputs to the
// calibrator as a batch and then executes the native segment, which is what
// produces the op's outputs.
//
// ctx:    kernel context supplying the inputs, the resource manager and the
//         device stream.
// helper: ref-counted completion helper; an extra reference is held for the
//         duration of this call.
//
// If looking up/creating the calibration resource fails, or an input has a
// dtype GetTensorAddress() does not handle, the error is set on `ctx` and the
// native segment is not run. A byte-size mismatch between an input and its
// preallocated device tensor aborts via CHECK_EQ.
void TRTEngineOp::ExecuteCalibration(OpKernelContext* ctx,
AsyncHelper* helper) {
VLOG(1) << "Executing TRT calibration: " << name();
// Keep the helper alive until this function returns; the ScopedUnref drops
// the reference on every exit path.
helper->Ref();
tensorflow::core::ScopedUnref sc(helper);
auto res_mgr = ctx->resource_manager();
TRTCalibrationResource* calib_res = nullptr;
// Fetch the calibration resource registered under this op's name, creating it
// through AllocateCalibrationResources() if it does not exist yet.
OP_REQUIRES_OK(
ctx,
res_mgr->LookupOrCreate(
"TF_TRT_Calibration", name(),
reinterpret_cast<SerializableResourceBase**>(&calib_res),
{[ctx, this](SerializableResourceBase** cr) -> tensorflow::Status {
return this->AllocateCalibrationResources(ctx, cr);
}}));
// Drops one reference on calib_res at scope exit (presumably the one handed
// out by LookupOrCreate - confirm against ResourceMgr).
tensorflow::core::ScopedUnref calib_sc(calib_res);
// TODO(aaroey): here we also add the resource to the ResourceMgr singleton.
// This is needed before we migrate all uses of calib_graph_to_infer_graph()
// to the new calibration workflow. After that we'll remove this block.
{
auto deprecated_rm =
TRTResourceManager::instance()->getManager("TRTCalibration");
TRTCalibrationResource* copied_resource = nullptr;
// Check whether the resource exists, and create it if not.
if (deprecated_rm->Lookup(funcdef_name_, "Calibrator", &copied_resource)
.ok()) {
// Do nothing if the resource exists.
// Only release the reference that the successful Lookup handed back.
copied_resource->Unref();
} else {
copied_resource = calib_res;
// Increase the refcount by 1 then transfer the ownership of that refcount
// to the ResourceMgr singleton.
copied_resource->Ref();
OP_REQUIRES_OK(ctx, deprecated_rm->Create(funcdef_name_, "Calibrator",
copied_resource));
}
}
int num_inputs = ctx->num_inputs();
// Pass input data to calibrator
// Maps each input placeholder name (kInputPHName + index) to the address of
// that input's data.
std::unordered_map<string, void*> input_data;
for (int i = 0; i < num_inputs; i++) {
const Tensor& t = ctx->input(i);
void* data_address = GetTensorAddress(&t);
if (data_address == nullptr) {
ctx->SetStatus(tensorflow::errors::InvalidArgument(
"Unsupported data type encountered in input ", i));
return;
}
// Check the allocated buffer is sufficient for input
// The device tensor was sized from the inputs seen when the resource was
// created, so a batch with a different byte size is treated as fatal.
const auto device_tensor =
calib_res->device_tensors_.at(i).AccessTensor(ctx);
CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
input_data.emplace(StrCat(kInputPHName, i), data_address);
}
VLOG(2) << "Filled map for sending";
// copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
// Obtains the CUDA stream of this op's device context.
const cudaStream_t* stream = CHECK_NOTNULL(
reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
->stream()
->implementation()
->GpuStreamMemberHack()));
// Hand the batch to the calibrator on that stream.
calib_res->calibrator_->setBatch(input_data, *stream);
test::AddTestValue(StrCat(name(), ":ExecuteCalibration"), "done");
VLOG(2) << "Passed calibration data";
// The op's outputs for this step come from the native segment.
ExecuteNativeSegment(ctx, helper);
}
bool TRTEngineOp::GetCompatibleCachedEngine(
const std::vector<TensorShape>& actual_input_shapes,
std::vector<TensorShape>* engine_input_shapes) {
const int batch_size = actual_input_shapes[0].dim_size(0);
int smallest_batch_size = -1;
// Output shape will always be the same as the input but we will overwrite the
// batch size.
*engine_input_shapes = actual_input_shapes;
for (const int cached_batch_size : cached_engine_batches_) {
// Check if compatible: batch <= cached batch.
//
// TODO(laigd): here it only compare the first dim a.k.a the batch size,
// we'll need to to support non-batch dimensions as well. This will be done
// as part of the offline conversion implementation.
if (batch_size <= cached_batch_size) {
// First case: first compatible engine found
// Second case: smaller batch size engine found
if ((smallest_batch_size == -1) ||
(cached_batch_size < smallest_batch_size)) {
smallest_batch_size = cached_batch_size;
// Overwrite batch size for output
for (int i = 0; i < engine_input_shapes->size(); i++) {
(*engine_input_shapes)[i].set_dim(0, smallest_batch_size);
}
}
}
}
return (smallest_batch_size != -1);
}
// Entry point of the asynchronous kernel.
//
// In calibration mode the call is forwarded to ExecuteCalibration(). Otherwise
// an engine is looked up (or built) for the current input shapes and executed;
// if no engine is available, or executing it asks for a retry, the native
// segment is run instead. `done` is wrapped in an AsyncHelper, which invokes
// it when the helper is destroyed.
void TRTEngineOp::ComputeAsync(OpKernelContext* ctx,
                               AsyncOpKernel::DoneCallback done) {
  AsyncHelper* const helper = new AsyncHelper(done);
  // Releases this function's reference on every return path.
  tensorflow::core::ScopedUnref helper_unref(helper);

  if (calibration_mode_) {
    ExecuteCalibration(ctx, helper);
    return;
  }

  // Collect the shapes of all inputs; they select the engine.
  std::vector<tensorflow::TensorShape> input_shapes;
  const int num_inputs = ctx->num_inputs();
  for (int idx = 0; idx < num_inputs; ++idx) {
    input_shapes.push_back(ctx->input(idx).shape());
  }

  EngineContext* const engine_context = GetEngine(input_shapes, ctx);
  if (engine_context->cuda_engine) {
    const bool needs_fallback = ExecuteTrtEngine(ctx, engine_context);
    if (!needs_fallback) return;
    LOG(WARNING) << "Failed to execute engine, "
                 << "retrying with native segment for " << name();
  } else {
    LOG(WARNING) << "Engine retrieval for input shapes: "
                 << TensorShapeUtils::ShapeListString(input_shapes)
                 << " failed. Running native segment for " << name();
  }
  // Both failure modes end up here.
  ExecuteNativeSegment(ctx, helper);
}
// Binds the op's inputs and freshly allocated outputs to the engine's binding
// slots and enqueues one execution on the op's CUDA stream.
//
// ctx:            kernel context providing inputs, output allocation and the
//                 device stream.
// engine_context: engine, execution context and the mutex guarding enqueue.
//
// Returns true when the caller should retry with the native segment, false
// otherwise. False is also returned when output allocation fails: in that case
// the error is set on `ctx` and no retry is requested because the same output
// cannot be allocated twice.
bool TRTEngineOp::ExecuteTrtEngine(OpKernelContext* ctx,
                                   EngineContext* engine_context) {
  VLOG(1) << "Executing TRT engine: " << name();
  constexpr bool kRetry = true;
  auto& cuda_engine = engine_context->cuda_engine;

  const int num_inputs = ctx->num_inputs();
  const int num_outputs = ctx->num_outputs();
  // All inputs must have the same batch size, so just get it from the first
  // input.
  const int num_batch = ctx->input(0).shape().dim_size(0);

  // One slot per binding, indexed by the engine's binding index.
  std::vector<void*> buffers(num_inputs + num_outputs);

  // ---- Inputs ----
  for (int i = 0; i < num_inputs; ++i) {
    const string input_name = StrCat(kInputPHName, i);
    const int binding_index = cuda_engine->getBindingIndex(input_name.c_str());
    if (binding_index == -1) {
      LOG(ERROR) << "Input node not found, at " << input_name;
      return kRetry;
    }

    const Tensor& input_tensor = ctx->input(i);
    const auto input_batch = input_tensor.shape().dim_size(0);
    if (num_batch != input_batch) {
      LOG(ERROR) << "Input data has inconsistent batch size: " << num_batch
                 << " vs " << input_batch;
      return kRetry;
    }

    const auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] = const_cast<void*>(
            static_cast<const void*>(input_tensor.flat<float>().data()));
        break;
      case nvinfer1::DataType::kHALF:
        LOG(ERROR) << "FP16 inputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT8:
        LOG(ERROR) << "INT8 inputs are not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] = const_cast<void*>(
            static_cast<const void*>(input_tensor.flat<int32>().data()));
        break;
      default:
        LOG(ERROR) << "Unknown TRT data type: " << int(dtype);
        return kRetry;
    }
  }

  // ---- Outputs ----
  for (int i = 0; i < num_outputs; ++i) {
    const string output_name = StrCat(kOutputPHName, i);
    const int binding_index = cuda_engine->getBindingIndex(output_name.c_str());
    if (binding_index == -1) {
      LOG(ERROR) << "Output node not found, at " << output_name;
      return kRetry;
    }

    // The engine's binding dimensions do not include the batch dimension, so
    // prepend it.
    const auto dims = cuda_engine->getBindingDimensions(binding_index);
    std::vector<int> full_dims;
    full_dims.reserve(dims.nbDims + 1);
    full_dims.push_back(num_batch);
    for (int j = 0; j < dims.nbDims; ++j) full_dims.push_back(dims.d[j]);

    TensorShape output_shape;
    const auto shape_status = TensorShapeUtils::MakeShape(
        full_dims.data(), full_dims.size(), &output_shape);
    if (!shape_status.ok()) {
      LOG(ERROR) << "Failed to get output shape: " << shape_status;
      return kRetry;
    }

    Tensor* output_tensor = nullptr;
    const auto alloc_status =
        ctx->allocate_output(i, output_shape, &output_tensor);
    if (!alloc_status.ok()) {
      LOG(ERROR) << "Allocating output failed with " << alloc_status;
      ctx->SetStatus(alloc_status);
      // Do not retry since we cannot allocate the same output twice.
      // TODO(aaroey): ideally we should retry, fix this.
      return !kRetry;
    }

    const auto dtype = cuda_engine->getBindingDataType(binding_index);
    switch (dtype) {
      case nvinfer1::DataType::kFLOAT:
        buffers[binding_index] =
            static_cast<void*>(output_tensor->flat<float>().data());
        break;
      case nvinfer1::DataType::kHALF:
        LOG(WARNING) << "half size is not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT8:
        LOG(WARNING) << "int8 is not supported yet!";
        return kRetry;
      case nvinfer1::DataType::kINT32:
        buffers[binding_index] =
            static_cast<void*>(output_tensor->flat<int32>().data());
        break;
      default:
        LOG(WARNING) << "Unknown TRT data type: " << static_cast<int>(dtype);
        return kRetry;
    }
  }

  // Copied from cuda_kernel_helper since it seems only valid in *.cu.cc files
  const cudaStream_t* stream = CHECK_NOTNULL(
      reinterpret_cast<const cudaStream_t*>(ctx->op_device_context()
                                                ->stream()
                                                ->implementation()
                                                ->GpuStreamMemberHack()));

  // nvinfer1::IExecutionContext::enqueue is not thread safe and we need a mutex
  // for it.
  tensorflow::mutex_lock lock(engine_context->mu);
  // TODO(jie): trt enqueue does not return error
  const auto enqueued = engine_context->execution_context->enqueue(
      num_batch, buffers.data(), *stream, nullptr);
  if (!enqueued) {
    LOG(WARNING) << "Failed to enqueue batch for TRT engine: " << name();
    return kRetry;
  }
  test::AddTestValue(StrCat(name(), ":ExecuteTrtEngine"), "done");
  // Synchronization will be done by TF.
  return !kRetry;
}
// Returns the engine context to use for inputs with the given shapes, building
// (and caching) an engine if necessary. Holds `engine_mutex_` for the whole
// call.
//
// input_shapes: shapes of the op's current inputs; the batch size is read from
//               dimension 0 of the first shape.
// ctx:          kernel context; provides the resource manager that holds the
//               engine cache and receives the error if the cache lookup fails.
//
// The result is never null. On any failure a pointer to a shared, empty
// EngineContext is returned; callers detect this by testing `cuda_engine`.
//
// Static engine mode: the serialized engine is deserialized once and cached
// under shapes whose batch dimension is the engine's max batch size. If
// creating the runtime or deserializing fails, the empty context is returned
// (nothing is cached) instead of dereferencing a null engine.
//
// Dynamic mode: a compatible cached batch size is used when available;
// otherwise an engine is built for the exact input shapes.
EngineContext* TRTEngineOp::GetEngine(
    const std::vector<TensorShape>& input_shapes, OpKernelContext* ctx) {
  static EngineContext empty_context;
  tensorflow::mutex_lock lock(engine_mutex_);
  // TODO(tmorris): using first input to get batch size - is this reliable?
  const int batch_size = input_shapes[0].dim_size(0);

  // Get engine cache
  TRTEngineCacheResource* cache_res = nullptr;
  auto status = ctx->resource_manager()->LookupOrCreate(
      "TRTEngineCache", funcdef_name_, &cache_res,
      {[this, ctx](TRTEngineCacheResource** cr) -> tensorflow::Status {
        *cr = new TRTEngineCacheResource(ctx, this->max_cached_engines_);
        return Status::OK();
      }});
  if (!status.ok()) {
    ctx->SetStatus(status);
    return &empty_context;
  }
  tensorflow::core::ScopedUnref sc(cache_res);
  auto& cache = cache_res->cache_;
  auto allocator = cache_res->allocator_.get();
  if (allocator == nullptr) {
    return &empty_context;
  }

  // Handle the static engine case. For static engines, the cache will have a
  // single element containing the only engine.
  if (static_engine_) {
    if (cache.size()) {
      // Batch size of engine must be >= the input batch size
      // TODO(tmorris): use match compatible function?
      if (cache.begin()->first[0].dim_size(0) >= batch_size) {
        return cache.begin()->second.get();
      }
      return &empty_context;
    }

    TrtUniquePtrType<IRuntime> infer(nvinfer1::createInferRuntime(logger));
    if (!infer) {
      LOG(ERROR) << "Failed to create TensorRT runtime for " << name();
      return &empty_context;
    }
    infer->setGpuAllocator(allocator);
    TrtUniquePtrType<nvinfer1::ICudaEngine> static_engine(
        infer->deserializeCudaEngine(serialized_segment_.c_str(),
                                     serialized_segment_.size(),
                                     PluginFactoryTensorRT::GetInstance()));
    if (!static_engine) {
      // Deserialization failed. Leave the cache and the serialized segment
      // untouched and let the caller fall back to the native segment.
      LOG(ERROR) << "Failed to deserialize static TensorRT engine for "
                 << name();
      return &empty_context;
    }
    auto raw_static_engine = static_engine.get();
    const auto max_batch_size = raw_static_engine->getMaxBatchSize();
    // Static engine will have max_batch_size for batch size so that all inputs
    // will map to this single engine.
    std::vector<TensorShape> engine_input_shapes(input_shapes);
    for (size_t i = 0; i < engine_input_shapes.size(); i++) {
      // TODO(tmorris): will all inputs have batch size as first dimension??
      engine_input_shapes[i].set_dim(0, max_batch_size);
    }
    // TODO(laigd): here we assume engine_input_shapes matches the actual input
    // shapes of the engine, we should verify that.
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(
                      std::move(static_engine),
                      TrtUniquePtrType<nvinfer1::IExecutionContext>(
                          raw_static_engine->createExecutionContext())));
    // Runtime is safe to delete after engine creation
    serialized_segment_.clear();
    if (max_batch_size < batch_size) {
      return &empty_context;
    }
    return cache.at(engine_input_shapes).get();
  }  // static_engine_

  // Handle the dynamic engine case.
  // See if there is a compatible engine cached. The batch size should be <= the
  // cached batch size.
  std::vector<tensorflow::TensorShape> engine_input_shapes;
  const bool matched_successfully =
      GetCompatibleCachedEngine(input_shapes, &engine_input_shapes);
  // If matched, use that engine. Otherwise, we will look in cache for that
  // exact shape and possibly create a new engine if it is not in cache.
  if (!matched_successfully) {
    engine_input_shapes = input_shapes;
    if (!cached_engine_batches_.empty()) {
      // If user has explicitly defined cached_engine_batches, we should
      // warn them that their input was non-compatible (batch size too high)
      LOG(WARNING) << "No compatible cached engine was found for batch size: "
                   << batch_size << ". A new engine will be created.";
      cached_engine_batches_.push_back(batch_size);
    }
  }

  if (!cache.count(engine_input_shapes)) {
    TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
    bool convert_successfully = false;
    LOG(INFO) << "Building a new TensorRT engine for " << name()
              << " input shapes: "
              << TensorShapeUtils::ShapeListString(engine_input_shapes);

    // Convert to partial shapes
    std::vector<PartialTensorShape> partial_shapes;
    partial_shapes.reserve(engine_input_shapes.size());
    for (const auto& shape : engine_input_shapes) {
      partial_shapes.emplace_back(shape);
    }

    // Up to this point, calibrator_ can never be empty, since otherwise it
    // means calibration_mode_ is true and this path won't get executed.
    auto convert_status = convert::ConvertGraphDefToEngine(
        segment_graph_, precision_mode_, batch_size, workspace_size_,
        partial_shapes, &logger, allocator, calibrator_.get(), &engine,
        use_calibration_, &convert_successfully);
    if (!convert_status.ok()) {
      if (convert_successfully) {
        // This means it fail to build the engine even when the network is built
        // successfully, probably due to internal issues. In this case we don't
        // retry in the future.
        cache.emplace(engine_input_shapes, absl::make_unique<EngineContext>());
      }
      LOG(WARNING) << "Engine creation for batch size " << batch_size
                   << " failed " << convert_status;
      return &empty_context;
    }
    VLOG(1) << "Conversion is done";
    TrtUniquePtrType<nvinfer1::IExecutionContext> exec_context(
        engine->createExecutionContext());
    cache.emplace(engine_input_shapes,
                  absl::make_unique<EngineContext>(std::move(engine),
                                                   std::move(exec_context)));
  }
  return cache.at(engine_input_shapes).get();
}
// Creates the calibration resource for this op and starts the calibration
// thread.
//
// ctx: kernel context; its current inputs determine the batch size and the
//      shape/dtype of the persistent device tensors allocated here.
// cr:  output. Receives the newly created TRTCalibrationResource. It is
//      assigned before any fallible step, so it is set even when an error is
//      returned.
//
// Returns:
//   - the error from allocate_persistent() if a device tensor can't be
//     allocated;
//   - InvalidArgument if an input has a dtype GetTensorAddress() does not
//     handle, or if the device exposes no usable GPU device info;
//   - OK once the calibration thread has been launched.
//
// NOTE(review): the calibration thread captures the raw resource pointer and a
// pointer to this op's `segment_graph_`; both must outlive the thread, which
// is not enforced in this function.
tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
    OpKernelContext* ctx, SerializableResourceBase** cr) {
  auto cres = new TRTCalibrationResource();
  *cr = cres;
  // Get the allocator.
  auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
  if (!alloc) {
    LOG(WARNING) << "Can't get device allocator will not be able to "
                    "allocate memory from TensorFlow memory pool";
    cres->allocator_.reset(new TRTCudaAllocator);
  } else {
    cres->allocator_.reset(new TRTDeviceAllocator(alloc));
  }
  // Get the input shapes.
  const int batch_size = ctx->input(0).dim_size(0);
  const int num_inputs = ctx->num_inputs();
  std::vector<tensorflow::PartialTensorShape> shapes;
  cres->device_tensors_.resize(num_inputs);
  VLOG(1) << " Constructing calibrator";
  for (int i = 0; i < num_inputs; i++) {
    // allocate workspace on device for inputs
    const tensorflow::Tensor& t = ctx->input(i);
    shapes.emplace_back(t.shape());
    Tensor* device_tensor;
    TF_RETURN_IF_ERROR(ctx->allocate_persistent(
        t.dtype(), t.shape(), &cres->device_tensors_.at(i), &device_tensor));
    CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes());
    void* device_address = GetTensorAddress(device_tensor);
    if (device_address == nullptr) {
      return tensorflow::errors::InvalidArgument(
          "Unsupported data type encountered in input ", i);
    }
    // Record the buffer address and size under the input placeholder name.
    cres->device_buffers_.emplace(
        StrCat(kInputPHName, i),
        std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
  }
  cres->calibrator_.reset(
      new TRTInt8Calibrator(cres->device_buffers_, batch_size, name()));
  const string label(name());
  auto segment_graph = &segment_graph_;
  // The device info pointer may be null; treat that the same as an invalid id
  // rather than dereferencing it.
  const auto* gpu_info = ctx->device()->tensorflow_gpu_device_info();
  const int platform_gpu_id = (gpu_info != nullptr) ? gpu_info->gpu_id : -1;
  if (platform_gpu_id < 0) {
    LOG(ERROR) << "Can't get gpu_device_info from context->device()";
    return tensorflow::errors::InvalidArgument(
        "Context->device doesn't contain device info!");
  }
  const int64 workspace_size_bytes = workspace_size_;
  cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
                                    platform_gpu_id, workspace_size_bytes]() {
    LOG(INFO) << "Starting calibration thread on device " << platform_gpu_id
              << ", Calibration Resource @ " << cres;
    auto err = cudaSetDevice(platform_gpu_id);
    if (err != cudaSuccess) {
      // TODO(aaroey): should return error here.
      LOG(ERROR) << "Couldn't set cuda device to " << platform_gpu_id
                 << " in calibration thread";
    }
    // ConvertGraphDefToEngine() will try to build the engine. This thread
    // will loop inside buildCudaEngine() consuming the calibration data
    // that is set by the TF op, and drive the builder until calibrator returns
    // false. Engine is discarded after calibration table is generated
    //
    // TODO(aaroey): maybe setting the max batch size using the python
    // calibration wrapper class.
    auto s = convert::ConvertGraphDefToEngine(
        *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(),
        workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
        cres->calibrator_.get(), &cres->engine_,
        /*use_calibration=*/true,
        /*convert_successfully=*/nullptr);
    if (!s.ok()) {
      LOG(ERROR) << "Calibration failed: " << s;
      cres->calibrator_->setDone();  // Ignore further pushes
    }
    VLOG(1) << "Calibration loop terminated " << label;
  }));
  VLOG(1) << "initialized calibrator resource";
  return tensorflow::Status::OK();
}
// Registers TRTEngineOp as the kernel for the "TRTEngineOp" op on GPU devices.
REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp);
} // namespace tensorrt
} // namespace tensorflow
#endif // GOOGLE_TENSORRT
#endif // GOOGLE_CUDA