Refactor ExecuteTrtEngine. #38118

Merged
16 changes: 16 additions & 0 deletions tensorflow/compiler/tf2tensorrt/BUILD
@@ -90,6 +90,7 @@ cc_library(
deps = [
":trt_allocator",
":trt_conversion",
":trt_engine_utils",
":trt_logging",
":trt_plugins",
":trt_resources",
@@ -215,9 +216,24 @@ cc_library(
deps = [
":get_calibration_data_op_op_lib",
":trt_engine_op_op_lib",
":trt_engine_utils",
],
)

tf_cuda_library(
name = "trt_engine_utils",
srcs = ["utils/trt_engine_utils.cc"],
hdrs = ["utils/trt_engine_utils.h"],
deps = [
":trt_logging",
":utils",
"@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:status",
] + if_tensorrt([":tensorrt_lib"]),
)

tf_cuda_library(
name = "trt_logging",
srcs = ["utils/trt_logger.cc"],
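For orientation, here is a minimal sketch of what the new utils/trt_engine_utils.h interface plausibly declares. It is inferred entirely from the call sites in the convert_nodes_test.cc diff below and from the InputOutputData struct that this PR removes from the test; the actual header is not part of this diff, so names, parameter order, and defaults may differ.

// Hypothetical sketch of utils/trt_engine_utils.h, inferred from the call
// sites in convert_nodes_test.cc below; not the authoritative header.
#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_
#define TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_

#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/status.h"
#include "third_party/tensorrt/NvInfer.h"

namespace tensorflow {
namespace tensorrt {

// Input/output data for engine execution; previously a local struct in
// convert_nodes_test.cc (see the lines removed in the hunk below).
struct InputOutputData {
  void* Buffer() const {
    return const_cast<char*>(tensor.tensor_data().data());
  }
  size_t TotalBytes() const { return tensor.TotalBytes(); }

  string name;
  Tensor tensor;
};

using DataVec = std::vector<InputOutputData>;

// Stores the input tensor addresses in buffers and, presumably, sets any
// binding dimensions the execution context needs.
Status SetTrtEngineInputs(nvinfer1::ICudaEngine* cuda_engine,
                          nvinfer1::IExecutionContext* execution_context,
                          int trt_profile_idx, std::vector<void*>& buffers,
                          bool use_implicit_batch, int batch_size,
                          OpKernelContext* ctx, const DataVec* input_vec);

// Allocates or binds output tensors and stores their addresses in buffers.
Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine,
                           nvinfer1::IExecutionContext* execution_context,
                           int trt_profile_idx, std::vector<void*>& buffers,
                           bool use_implicit_batch, int batch_size,
                           OpKernelContext* ctx, DataVec* outputs);

// Enqueues the engine for asynchronous execution on stream.
Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context,
                  std::vector<void*>& buffers, cudaStream_t stream,
                  bool use_implicit_batch, int batch_size);

}  // namespace tensorrt
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_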
46 changes: 27 additions & 19 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -30,6 +30,8 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor.h"
@@ -1213,26 +1215,12 @@ TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
TF_EXPECT_OK(RunConvertGraphDefToEngine(&s));
}

// Input/output data format for OpConverterTest::BuildAndRun().
struct InputOutputData {
void* Buffer() const {
return const_cast<char*>(tensor.tensor_data().data());
}

size_t TotalBytes() const { return tensor.TotalBytes(); }

string name;
Tensor tensor;
};

template <typename T>
Tensor ConstructTensor(int data_size, const T& value = T()) {
std::vector<T> values(data_size, value);
return test::AsTensor<T>(values);
}

using DataVec = std::vector<InputOutputData>;

template <typename T>
inline absl::Span<const T> GetSpanForData(const InputOutputData& data) {
const auto& tensor_map = data.tensor.flat<T>();
@@ -1308,10 +1296,31 @@ class OpConverterTest : public ::testing::Test {
CheckDataTypeMatches(input_data);
CheckDataTypeMatches(*output_data);

// Execute the TRT engine.
const int num_bindings = input_data.size() + output_data->size();
std::vector<void*> buffers(num_bindings);

ASSERT_EQ(engine_->getNbBindings(), num_bindings);
TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
engine_->createExecutionContext());

// Prepare input bindings.
TF_ASSERT_OK(SetTrtEngineInputs(engine_.get(), execution_context.get(), 0,
buffers, converter_->use_implicit_batch(),
batch_size, nullptr, &input_data));

// Prepare output bindings.
TF_ASSERT_OK(SetTrtEngineOutputs(engine_.get(), execution_context.get(), 0,
buffers, converter_->use_implicit_batch(),
batch_size, nullptr, output_data));

// Allocate buffers on GPU and copy data there. This is necessary because
// the test tensors are allocated in host memory, so the pointers that
// SetTrtEngine(In|Out)puts placed into buffers[] cannot be used on the GPU.
// We allocate the GPU buffers, copy the data there, and overwrite the
// addresses in the buffers array.
//
// TODO(tfeher): This step can be avoided if we allocate the Tensors in
// unified memory.
for (const auto& data : input_data) {
const int input_index = engine_->getBindingIndex(data.name.c_str());
ASSERT_NE(-1, input_index);
@@ -1334,10 +1343,9 @@ class OpConverterTest : public ::testing::Test {
ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes()));
}

ASSERT_EQ(engine_->getNbBindings(), num_bindings);
TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
engine_->createExecutionContext());
execution_context->enqueue(batch_size, buffers.data(), stream_, nullptr);
// Execute the TRT engine.
TF_ASSERT_OK(TrtEnqueue(execution_context.get(), buffers, stream_,
converter_->use_implicit_batch(), batch_size));

for (int i = 0; i < output_infos.size(); ++i) {
const auto& output_info = output_infos[i];
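The refactor replaces the direct execution_context->enqueue(batch_size, buffers.data(), stream_, nullptr) call with TrtEnqueue. A hedged sketch of that helper, assuming it merely dispatches between TensorRT's implicit-batch enqueue and explicit-batch enqueueV2 and reports failures as a Status; the PR's real implementation may add more error handling.

// Hedged sketch of TrtEnqueue, inferred from the call it replaces above.
#include "tensorflow/core/lib/core/errors.h"

Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context,
                  std::vector<void*>& buffers, cudaStream_t stream,
                  bool use_implicit_batch, int batch_size) {
  bool ret = false;
  if (use_implicit_batch) {
    // Implicit-batch engines take the batch size at enqueue time, exactly
    // like the execution_context->enqueue(...) call this PR removes.
    ret = execution_context->enqueue(batch_size, buffers.data(), stream,
                                     nullptr);
  } else {
    // Explicit-batch engines carry the batch in the binding dimensions,
    // which SetTrtEngineInputs is assumed to have set on the context.
    ret = execution_context->enqueueV2(buffers.data(), stream, nullptr);
  }
  if (!ret) {
    return errors::Internal("Failed to enqueue batch for TRT engine");
  }
  return Status::OK();
}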
22 changes: 22 additions & 0 deletions tensorflow/compiler/tf2tensorrt/convert/utils.cc
@@ -163,6 +163,28 @@ bool AreShapesCompatible(const std::vector<TensorShape>& actual_shapes,
return true;
}

Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
bool use_implicit_batch, int batch_size,
TensorShape& shape) {
TF_RETURN_IF_ERROR(
TensorShapeUtils::MakeShape(trt_dims.data(), trt_dims.size(), &shape));
if (use_implicit_batch) {
shape.InsertDim(0, batch_size);
}
return Status::OK();
}

Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
bool use_implicit_batch, int batch_size,
TensorShape& shape) {
TF_RETURN_IF_ERROR(
TensorShapeUtils::MakeShape(trt_dims.d, trt_dims.nbDims, &shape));
if (use_implicit_batch) {
shape.InsertDim(0, batch_size);
}
return Status::OK();
}

int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) {
int n_bindings = engine->getNbBindings();
int n_input = 0;
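A short usage sketch of the new overload (illustrative values, not from the PR): in implicit-batch mode the TRT dims exclude the batch dimension, so batch_size is prepended at position 0; in explicit-batch mode batch_size is ignored because the dims already carry the batch.

// Illustrative use of TrtDimsToTensorShape; dims values are made up.
Status TrtDimsToTensorShapeExample() {
  nvinfer1::Dims dims;
  dims.nbDims = 3;
  dims.d[0] = 1;
  dims.d[1] = 28;
  dims.d[2] = 28;

  TensorShape shape;
  // Implicit batch: batch_size=8 is inserted at dim 0 -> [8, 1, 28, 28].
  TF_RETURN_IF_ERROR(TrtDimsToTensorShape(dims, /*use_implicit_batch=*/true,
                                          /*batch_size=*/8, shape));

  // Explicit batch: dims already include the batch, batch_size is unused,
  // so the result is simply [1, 28, 28].
  TF_RETURN_IF_ERROR(TrtDimsToTensorShape(dims, /*use_implicit_batch=*/false,
                                          /*batch_size=*/1, shape));
  return Status::OK();
}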
8 changes: 8 additions & 0 deletions tensorflow/compiler/tf2tensorrt/convert/utils.h
@@ -98,6 +98,14 @@ inline nvinfer1::Dims TensorShapeToTrtDims(const TensorShapeType& shape,
return trt_dims;
}

Status TrtDimsToTensorShape(const std::vector<int>& trt_dims,
bool use_implicit_batch, int batch_size,
TensorShape& shape);

Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims,
bool use_implicit_batch, int batch_size,
TensorShape& shape);

// Returns a string that includes compile time TensorRT library version
// information {Maj, Min, Patch}.
string GetLinkedTensorRTVersion();
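These declarations complement TensorShapeToTrtDims above, making the TensorShape/Dims conversion round-trippable. A hypothetical round trip, assuming the inline helper's second parameter (truncated in this diff) is an ignore_first_dim flag:

// Hypothetical round trip between TensorShape and nvinfer1::Dims; the
// ignore_first_dim parameter name is an assumption, values illustrative.
Status RoundTripExample() {
  TensorShape shape({8, 3, 224, 224});

  // Implicit batch: drop the batch dim for TRT -> Dims [3, 224, 224].
  nvinfer1::Dims dims =
      TensorShapeToTrtDims(shape, /*ignore_first_dim=*/true);

  // Re-inserting the batch size restores the original [8, 3, 224, 224].
  TensorShape restored;
  TF_RETURN_IF_ERROR(TrtDimsToTensorShape(dims, /*use_implicit_batch=*/true,
                                          /*batch_size=*/8, restored));
  if (shape != restored) {
    return errors::Internal("Round trip mismatch");
  }
  return Status::OK();
}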