Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Let the segmenter detect invalid input/output segment nodes beforehand. #20755

Merged
merged 6 commits into from
Jul 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion tensorflow/contrib/tensorrt/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,15 @@ tf_cc_test(
tags = ["no_windows"],
deps = [
":segment",
"//tensorflow/c:c_api",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)

Expand Down
43 changes: 24 additions & 19 deletions tensorflow/contrib/tensorrt/convert/convert_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
#include "tensorflow/core/protobuf/config.pb.h" // NOLINT
#include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT
#include "tensorflow/core/protobuf/rewriter_config.pb.h" // NOLINT
#include "tensorflow/core/util/device_name_utils.h"

#if GOOGLE_CUDA
Expand Down Expand Up @@ -301,7 +301,8 @@ tensorflow::Status GetEngineInfo(
const int node_id = node->id();
for (const auto edge : node->in_edges()) {
auto input_node = edge->src();
if (segment_nodes.count(input_node->name()) == 0) {
if (segment_nodes.count(input_node->name()) == 0 &&
!edge->IsControlEdge() && !input_node->IsSource()) {
// Add constant input node into the segment. We don't care if it has
// other output edges going into other engines or TF nodes. Since we add
// it only to the subsegment node list, not the subsegment itself, it
Expand All @@ -312,7 +313,7 @@ tensorflow::Status GetEngineInfo(
added_const_node_ids.insert(input_node->id());
subgraph_node_ids.push_back(input_node->id());
}
} else if (!edge->IsControlEdge() && !input_node->IsSource()) {
} else {
string s(input_node->name());
StrAppend(&s, ":", edge->src_output());
VLOG(1) << "Input edge = " << s;
Expand Down Expand Up @@ -378,9 +379,9 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
nvinfer1::IGpuAllocator* alloc,
int max_batch_size) {
const auto& info = infos.at(pos);
std::vector<tensorflow::TensorShapeProto> out_shapes;
std::vector<tensorflow::TensorShapeProto> input_shapes;
std::vector<tensorflow::PartialTensorShape> shapes;
std::vector<tensorflow::TensorShapeProto> output_shape_protos;
std::vector<tensorflow::TensorShapeProto> input_shape_protos;
std::vector<tensorflow::PartialTensorShape> input_shapes;
std::vector<tensorflow::NodeDefBuilder::NodeOut> inputs;
std::vector<tensorflow::DataType> out_types;
VLOG(1) << "Processing " << info.engine_name;
Expand All @@ -393,24 +394,24 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
tensorflow::TensorShapeProto out_shape;
// shape of the output node inside segment
conn.inside_shape.AsProto(&out_shape);
if (out_shapes.size() <= conn.port_number) {
out_shapes.resize(conn.port_number + 1);
if (output_shape_protos.size() <= conn.port_number) {
output_shape_protos.resize(conn.port_number + 1);
out_types.resize(conn.port_number + 1);
}
out_shapes.at(conn.port_number) = out_shape;
output_shape_protos.at(conn.port_number) = out_shape;
out_types.at(conn.port_number) = conn.connection_type;
continue;
}

// Set the shapes and data types of input edge.
tensorflow::TensorShapeProto in_shape;
conn.outside_shape.AsProto(&in_shape);
if (input_shapes.size() <= conn.port_number) {
if (input_shape_protos.size() <= conn.port_number) {
input_shape_protos.resize(conn.port_number + 1);
input_shapes.resize(conn.port_number + 1);
shapes.resize(conn.port_number + 1);
}
input_shapes.at(conn.port_number) = in_shape;
shapes.at(conn.port_number) = conn.outside_shape;
input_shape_protos.at(conn.port_number) = in_shape;
input_shapes.at(conn.port_number) = conn.outside_shape;

string input_node = conn.outside_node_name;
int input_port = conn.outside_port;
Expand Down Expand Up @@ -438,6 +439,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
VLOG(1) << "Engine Input " << input_node << ":" << input_port << " -> "
<< info.engine_name << ":" << inputs.size();
// Skip duplicate inputs.
// TODO(aaroey): use std::find instead. GetEngineInfo already removes
// duplicate connections, so we should never find any duplicates here?
bool new_input = true;
for (const auto& inp : inputs) {
if (inp.node == input_node && inp.index == input_port) {
Expand Down Expand Up @@ -465,8 +468,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
info.segment_graph_def,
info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
alloc, /*calibrator=*/nullptr, &engine,
max_batch_size, info.max_workspace_size_bytes, input_shapes,
&trt_logger, alloc, /*calibrator=*/nullptr, &engine,
/*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string =
Expand Down Expand Up @@ -514,8 +517,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
}
tensorflow::NodeDef trt_node;
tensorflow::Status status =
node_builder.Attr("input_shapes", input_shapes)
.Attr("output_shapes", out_shapes)
node_builder.Attr("input_shapes", input_shape_protos)
.Attr("output_shapes", output_shape_protos)
.Attr("static_engine",
info.engine_type == EngineInfo::EngineType::TRTStatic)
.Attr("segment_funcdef_name",
Expand Down Expand Up @@ -734,6 +737,7 @@ std::pair<int, tensorflow::Allocator*> GetDeviceAndAllocator(
}

// Entry function from optimization pass.
// TODO(aaroey): parameter should use pointer type.
tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
// Convert graphdef to graph.
tensorflow::FunctionLibraryDefinition flib(tensorflow::OpRegistry::Global(),
Expand All @@ -751,7 +755,8 @@ tensorflow::Status ConvertAfterShapes(ConversionParams& params) {
segment_options.minimum_segment_size = params.minimum_segment_size;
tensorflow::tensorrt::segment::SegmentNodesVector initial_segments;
TF_RETURN_IF_ERROR(tensorrt::segment::SegmentGraph(
&graph, IsTensorRTCandidate, segment_options, &initial_segments));
&graph, IsTensorRTCandidate, InputEdgeValidator(*params.graph_properties),
OutputEdgeValidator(), segment_options, &initial_segments));
if (initial_segments.size() > 1) {
VLOG(0) << "MULTIPLE tensorrt candidate conversion: "
<< initial_segments.size();
Expand Down
168 changes: 116 additions & 52 deletions tensorflow/contrib/tensorrt/convert/convert_nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <cstring>
#include <list>
#include <map>
#include <memory>
Expand Down Expand Up @@ -77,7 +78,6 @@ namespace tensorflow {
namespace tensorrt {
namespace convert {
using ::tensorflow::str_util::Split;

using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

Expand Down Expand Up @@ -107,6 +107,59 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
return tensorflow::Status::OK();
}

// Looks up the dtype and shape of the tensor produced at output `out_port` of
// `outside_node`.
//
// Args:
//   graph_properties: inferred graph properties, queried by node name.
//   outside_node: node that produces the tensor.
//   out_port: output index on `outside_node`.
//   shape: output; assigned only when `graph_properties` has output properties
//     for the node, otherwise left untouched.
//   dtype: output; always assigned.
void GetInputProperties(const grappler::GraphProperties& graph_properties,
                        const Node* outside_node, const int out_port,
                        PartialTensorShape* shape,
                        tensorflow::DataType* dtype) {
  if (graph_properties.HasOutputProperties(outside_node->name())) {
    // Bind by const reference: the property list and its entry are only read
    // here, so copying the vector and the element is unnecessary.
    const auto& output_params =
        graph_properties.GetOutputProperties(outside_node->name());
    const auto& out_shape = output_params.at(out_port);
    *dtype = out_shape.dtype();
    *shape = out_shape.shape();
  } else {
    VLOG(0) << "Unknown output shape for " << outside_node->name();
    // No inferred properties: fall back to the dtype reported by the node.
    *dtype = outside_node->output_type(out_port);
  }
}

// Looks up the dtype and shape of the tensor consumed at input `in_port` of
// `outside_node`.
//
// Args:
//   graph_properties: inferred graph properties, queried by node name.
//   outside_node: node that consumes the tensor.
//   in_port: input index on `outside_node`.
//   shape: output; assigned only when `graph_properties` has input properties
//     for the node, otherwise left untouched.
//   dtype: output; always assigned.
void GetOutputProperties(const grappler::GraphProperties& graph_properties,
                         const Node* outside_node, const int in_port,
                         PartialTensorShape* shape,
                         tensorflow::DataType* dtype) {
  if (graph_properties.HasInputProperties(outside_node->name())) {
    // Bind by const reference: the property list and its entry are only read
    // here, so copying the vector and the element is unnecessary.
    const auto& input_params =
        graph_properties.GetInputProperties(outside_node->name());
    const auto& in_shape = input_params.at(in_port);
    *dtype = in_shape.dtype();
    *shape = in_shape.shape();
  } else {
    // No inferred properties: fall back to the dtype reported by the node.
    *dtype = outside_node->input_type(in_port);
  }
}

// Validates the shape and dtype of a segment input and converts the dtype to
// its TensorRT counterpart.
//
// Args:
//   shape: shape of the input tensor. Dimension 0 is skipped by the
//     per-dimension check (the error text calls the others "non-batch").
//   dtype: TensorFlow data type, passed to ConvertDType().
//   trt_dtype: output; filled in by ConvertDType().
//
// Returns:
//   - whatever error ConvertDType() reports (propagated via
//     TF_RETURN_IF_ERROR),
//   - InvalidArgument if shape.dims() < 0 (rank unknown) or if any dimension
//     with index >= 1 has a negative size,
//   - OutOfRange if shape.dims() > 8,
//   - OK otherwise.
//
// NOTE(review): the reviewer comment embedded in this block states that
// TensorRT supports 8 dimensions excluding the batch dimension, which would
// make the correct bound `shape.dims() > 9`. The author agreed to change it;
// confirm against the merged version before relying on the bound shown here.
tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape,
const tensorflow::DataType dtype,
nvinfer1::DataType* trt_dtype) {
// TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so
// put them there instead.
TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype));
// A negative rank is reported as "unknown".
if (shape.dims() < 0) {
return tensorflow::errors::InvalidArgument("Input tensor rank is unknown.");
}
// Upper bound on the rank, counted over all dimensions of `shape`.
if (shape.dims() > 8) {
return tensorflow::errors::OutOfRange(
"Input tensor rank is greater than 8.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TRT supports tensor with 8 dimension EXCLUDING batch dimension. So we should be able to support shape.dims() <= 9.
It is possible that the old code got the check condition wrong.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out. I'll fix this before this is merged.

}
for (int d = 1; d < shape.dims(); ++d) {
if (shape.dim_size(d) < 0) {
return tensorflow::errors::InvalidArgument(
"Input tensor has a unknown non-batch dimemension at dim ", d);
}
}
return Status::OK();
}

// Return whether or not the broadcast is feasible;
bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l,
const bool operand_l_is_tensor,
Expand Down Expand Up @@ -2640,25 +2693,22 @@ tensorflow::Status ConvertGraphDefToEngine(
(node_def.op() == "Placeholder")) {
nvinfer1::DimsCHW input_dim_pseudo_chw;
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0;
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
auto type_status =
ConvertDType(node_def.attr().at("dtype").type(), &dtype);
if (type_status != tensorflow::Status::OK()) {
LOG(WARNING) << "Type conversion failed for " << node_name;
return type_status;
}
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8,
&slot_number)) {
LOG(ERROR) << "Failed to parse slot number from " << node_name
<< " +8= " << node_name.c_str() + 8;
if (!tensorflow::strings::safe_strto32(
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
return tensorflow::errors::InvalidArgument(
"Failed to parse slot number from ", node_name);
}
nvinfer1::DataType dtype;
auto shape = input_shapes.at(slot_number);
if (shape.dims() > 8) {
LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name
<< " at input slot " << slot_number;
return tensorflow::errors::OutOfRange(
"Input tensor rank is greater than 8");
auto status = ValidateInputProperties(
shape, node_def.attr().at("dtype").type(), &dtype);
if (!status.ok()) {
const string error_message =
StrCat("Validation failed for ", node_name, " and input slot ",
slot_number, ": ", status.error_message());
LOG(WARNING) << error_message;
return Status(status.code(), error_message);
}
if (VLOG_IS_ON(1)) {
string dim_str("dims=");
Expand Down Expand Up @@ -2689,10 +2739,10 @@ tensorflow::Status ConvertGraphDefToEngine(
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
(node_def.op() == "Identity")) {
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9,
&slot_number)) {
LOG(ERROR) << "Failed to parse slot number from " << node_name
<< " +9=" << node_name.c_str() + 9;
if (!tensorflow::strings::safe_strto32(
node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
return tensorflow::errors::InvalidArgument(
"Failed to parse slot number from ", node_name);
}
if (output_tensors.size() <= slot_number) {
output_tensors.resize(slot_number + 1);
Expand Down Expand Up @@ -2751,38 +2801,20 @@ tensorflow::Status ConvertSegmentToGraphDef(
"Cannot find node with id ", connection.outside_id, " in the graph.");
}
// Updates the shape and data types of input/output connections.
tensorflow::DataType input_type = tensorflow::DT_FLOAT;
tensorflow::DataType dtype;
tensorflow::PartialTensorShape partial_shape;
if (connection.is_input_edge) {
if (graph_properties.HasOutputProperties(connection.outside_node_name)) {
auto output_params =
graph_properties.GetOutputProperties(connection.outside_node_name);
auto out_shape = output_params.at(connection.outside_port);
input_type = out_shape.dtype();
std::vector<tensorflow::int64> dims;
partial_shape = out_shape.shape();
connection.outside_shape = partial_shape;
} else {
VLOG(0) << "Unknown output shape" << outside_node->name();
input_type = graph->FindNodeId(connection.outside_id)
->output_type(connection.outside_port);
}
connection.connection_type = input_type;

} else { // output edge
if (graph_properties.HasInputProperties(connection.outside_node_name)) {
auto input_params =
graph_properties.GetInputProperties(connection.outside_node_name);
auto in_shape = input_params.at(connection.outside_port);
input_type = in_shape.dtype();
partial_shape = in_shape.shape();
connection.inside_shape = partial_shape;
} else {
input_type = graph->FindNodeId(connection.inside_id)
->output_type(connection.outside_port);
}
connection.connection_type = input_type;
GetInputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);

} else {
GetOutputProperties(graph_properties,
graph->FindNodeId(connection.outside_id),
connection.outside_port, &partial_shape, &dtype);
}
connection.outside_shape = partial_shape;
connection.connection_type = dtype;

// Add dummy input/output nodes to the segment graphdef.
if (connection.is_input_edge) {
Expand All @@ -2798,7 +2830,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Placeholder");
auto status = builder.Attr("shape", partial_shape)
.Attr("dtype", input_type)
.Attr("dtype", dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing input " << node_name << " for the edge "
<< connection.outside_node_name << ":" << connection.outside_port
Expand All @@ -2816,7 +2848,7 @@ tensorflow::Status ConvertSegmentToGraphDef(
marker_nodes.insert(node_name);
auto seg_node = segment_def->add_node();
tensorflow::NodeDefBuilder builder(node_name, "Identity");
auto status = builder.Input(connection.inside_node_name, 0, input_type)
auto status = builder.Input(connection.inside_node_name, 0, dtype)
.Finalize(seg_node);
VLOG(1) << "Constructing output " << node_name << " for the edge "
<< connection.inside_node_name << ":" << connection.inside_port
Expand Down Expand Up @@ -2854,6 +2886,38 @@ tensorflow::Status ConvertSegmentToGraphDef(
return tensorflow::Status::OK();
}

// Decides whether `in_edge` is acceptable as an input edge.
//
// Control edges are always accepted. For any other edge, the shape and dtype
// of the source output are fetched with GetInputProperties() and checked with
// ValidateInputProperties(). The edge is rejected (returns false) when that
// validation fails, or when the shape has fewer than 3 dimensions and the
// source node's type is not "Const". Each rejection is logged at VLOG(2) with
// the name of the edge's destination node.
//
// NOTE(review): per the reviewer comment embedded in this block, the
// `dims() < 3` rule may stem from a TensorRT 3.0 rank requirement that was
// dropped in TensorRT 4.0; the author deferred the change to a later PR.
// Confirm whether a version guard is needed.
bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const {
if (in_edge->IsControlEdge()) return true;
PartialTensorShape shape;
tensorflow::DataType dtype;
// Properties are looked up on the producer side of the edge.
GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(),
&shape, &dtype);
// Only the validation status is used here; the converted dtype is discarded.
nvinfer1::DataType trt_dtype;
Status status = ValidateInputProperties(shape, dtype, &trt_dtype);
if (!status.ok()) {
VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
<< ": " << status;
return false;
}
// Low-rank inputs are only tolerated when they come from a Const node.
if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shape.dims() < 3?
Not quite sure about the check.

TRT 3.0 requires input to be of rank 4 (batch + 3 dimension). This requirement has been removed in TRT 4.0.
If this check is related to that, we need to add a compile guard.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, I'll leave this as is since there are more places to change besides this one regarding the 3 dims requirement. And we can issue another PR to fix the problem.

VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name()
<< " which has an input at port " << in_edge->dst_input()
<< " with #dim<3 and is not a const: " << shape;
return false;
}
return true;
}

// Decides whether `out_edge` is acceptable as an output edge.
// Control edges are always accepted; a data edge is rejected (and the reason
// logged at VLOG(2)) when its source node is of type "Const".
bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const {
  if (out_edge->IsControlEdge()) {
    return true;
  }
  const bool produced_by_const = (out_edge->src()->type_string() == "Const");
  if (!produced_by_const) {
    return true;
  }
  VLOG(2) << "--> Need to remove output node " << out_edge->src()->name()
          << " which is a Const.";
  return false;
}

} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
Expand Down