Skip to content

Commit

Permalink
TF-TRT C++ interface for conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Oct 8, 2021
1 parent 0cafc80 commit 5cbde12
Show file tree
Hide file tree
Showing 6 changed files with 758 additions and 0 deletions.
57 changes: 57 additions & 0 deletions tensorflow/compiler/tf2tensorrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,62 @@ tf_cuda_cc_test(
]),
)

cc_library(
name = "trt_convert_api",
srcs = ["trt_convert_api.cc"],
hdrs = [
"trt_convert_api.h",
],
copts = tf_copts(),
deps = [
":trt_resources",
"@com_google_absl//absl/strings",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/platform:logging",
] + if_tensorrt([":tensorrt_lib"]),
)

tf_cuda_cc_test(
name = "trt_convert_api_test",
size = "small",
srcs = ["trt_convert_api_test.cc"],
tags = [
"no_cuda_on_cpu_tap",
"no_windows",
"nomac",
],
deps = [
":common_utils",
":testutils",
":trt_convert_api",
":trt_conversion",
":trt_logging",
":trt_op_kernels",
":trt_resources",
":utils",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:resource_variable_ops",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:no_op_op_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:state_ops_op_lib",
],
)

cc_library(
name = "common_utils",
srcs = ["common/utils.cc"],
Expand Down Expand Up @@ -279,6 +335,7 @@ cc_library(
name = "trt_op_libs",
deps = [
":get_calibration_data_op_op_lib",
":trt_convert_api",
":trt_engine_op_op_lib",
":trt_engine_utils",
],
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/compiler/tf2tensorrt/convert/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ absl::optional<DeviceNameUtils::ParsedName> MergeIfCompatible(
const DeviceNameUtils::ParsedName& a, absl::string_view b);

// Optimization profile generation strategies.
// - `kRange`: create one profile that works for inputs with dimension values
// in the range of [min_dims, max_dims] where min_dims and max_dims are
// derived from the provided inputs.
// - `kOptimal`: create one profile for each input. The profile only works for
// inputs with the same dimensions as the input it is created for. The GPU
// engine will be run with optimal performance with such inputs.
// - `kRangeOptimal`: create the profiles for both `Range` and `Optimal`.
// - `kImplicitBatchModeCompatible`: create the profiles that will produce the
// same GPU engines as the implicit_batch_mode would produce.
enum class ProfileStrategy {
kRange,
kOptimal,
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,9 @@ TRTEngineOp::TRTEngineOp(OpKernelConstruction* context)
<< ", thus setting _profile_generation_mode=false";
profile_generation_mode_ = false;
}
if (static_engine_) {
if (profile_generation_mode_) profile_generation_mode_ = false;
}
if (use_implicit_batch_) {
OP_REQUIRES(context, !profile_generation_mode_,
errors::InvalidArgument(
Expand Down

0 comments on commit 5cbde12

Please sign in to comment.