Skip to content

Commit

Permalink
TF-TRT C++ interface for conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Sep 26, 2021
1 parent 0f55a2c commit d876eae
Show file tree
Hide file tree
Showing 6 changed files with 778 additions and 0 deletions.
56 changes: 56 additions & 0 deletions tensorflow/compiler/tf2tensorrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,62 @@ tf_cuda_cc_test(
]),
)

cc_library(
name = "trt_api",
srcs = ["trt_convert.cc"],
hdrs = [
"trt_convert.h",
],
copts = tf_copts(),
deps = [
":trt_resources",
"//tensorflow/cc/saved_model:loader",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/platform:logging",
] + if_tensorrt([":tensorrt_lib"]),
)

tf_cuda_cc_test(
name = "trt_api_test",
size = "small",
srcs = ["trt_convert_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":common_utils",
":testutils",
":trt_api",
":trt_conversion",
":trt_logging",
":trt_op_kernels",
":trt_resources",
":utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:no_op_op_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:state_ops_op_lib",
],
)

cc_library(
name = "common_utils",
srcs = ["common/utils.cc"],
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
<< ", thus setting _profile_generation_mode=false";
profile_generation_mode_ = false;
}
if (static_engine_) {
if (profile_generation_mode_) profile_generation_mode_ = false;
}
if (use_implicit_batch_) {
OP_REQUIRES(context, !profile_generation_mode_,
errors::InvalidArgument(
Expand Down

0 comments on commit d876eae

Please sign in to comment.