-
Notifications
You must be signed in to change notification settings - Fork 74k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Let segmenter able to detect invalid input/output segment nodes beforehand. #20755
Changes from all commits
571d3dc
86f632e
f340242
e02fbb2
571f7a2
482b056
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ limitations under the License. | |
#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" | ||
|
||
#include <algorithm> | ||
#include <cstring> | ||
#include <list> | ||
#include <map> | ||
#include <memory> | ||
|
@@ -77,7 +78,6 @@ namespace tensorflow { | |
namespace tensorrt { | ||
namespace convert { | ||
using ::tensorflow::str_util::Split; | ||
|
||
using ::tensorflow::strings::StrAppend; | ||
using ::tensorflow::strings::StrCat; | ||
|
||
|
@@ -107,6 +107,59 @@ inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype, | |
return tensorflow::Status::OK(); | ||
} | ||
|
||
void GetInputProperties(const grappler::GraphProperties& graph_properties, | ||
const Node* outside_node, const int out_port, | ||
PartialTensorShape* shape, | ||
tensorflow::DataType* dtype) { | ||
if (graph_properties.HasOutputProperties(outside_node->name())) { | ||
auto output_params = | ||
graph_properties.GetOutputProperties(outside_node->name()); | ||
auto out_shape = output_params.at(out_port); | ||
*dtype = out_shape.dtype(); | ||
*shape = out_shape.shape(); | ||
} else { | ||
VLOG(0) << "Unknown output shape" << outside_node->name(); | ||
*dtype = outside_node->output_type(out_port); | ||
} | ||
} | ||
|
||
void GetOutputProperties(const grappler::GraphProperties& graph_properties, | ||
const Node* outside_node, const int in_port, | ||
PartialTensorShape* shape, | ||
tensorflow::DataType* dtype) { | ||
if (graph_properties.HasInputProperties(outside_node->name())) { | ||
auto input_params = | ||
graph_properties.GetInputProperties(outside_node->name()); | ||
auto in_shape = input_params.at(in_port); | ||
*dtype = in_shape.dtype(); | ||
*shape = in_shape.shape(); | ||
} else { | ||
*dtype = outside_node->input_type(in_port); | ||
} | ||
} | ||
|
||
tensorflow::Status ValidateInputProperties(const PartialTensorShape& shape, | ||
const tensorflow::DataType dtype, | ||
nvinfer1::DataType* trt_dtype) { | ||
// TODO(aaroey): some of these checks also apply to IsTensorRTCandidate(), so | ||
// put them there instead. | ||
TF_RETURN_IF_ERROR(ConvertDType(dtype, trt_dtype)); | ||
if (shape.dims() < 0) { | ||
return tensorflow::errors::InvalidArgument("Input tensor rank is unknown."); | ||
} | ||
if (shape.dims() > 8) { | ||
return tensorflow::errors::OutOfRange( | ||
"Input tensor rank is greater than 8."); | ||
} | ||
for (int d = 1; d < shape.dims(); ++d) { | ||
if (shape.dim_size(d) < 0) { | ||
return tensorflow::errors::InvalidArgument( | ||
"Input tensor has a unknown non-batch dimemension at dim ", d); | ||
} | ||
} | ||
return Status::OK(); | ||
} | ||
|
||
// Return whether or not the broadcast is feasible; | ||
bool TensorRTGetBroadcastShape(const nvinfer1::Dims& operand_l, | ||
const bool operand_l_is_tensor, | ||
|
@@ -2640,25 +2693,22 @@ tensorflow::Status ConvertGraphDefToEngine( | |
(node_def.op() == "Placeholder")) { | ||
nvinfer1::DimsCHW input_dim_pseudo_chw; | ||
for (int i = 0; i < 8; i++) input_dim_pseudo_chw.d[i] = 0; | ||
nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT); | ||
auto type_status = | ||
ConvertDType(node_def.attr().at("dtype").type(), &dtype); | ||
if (type_status != tensorflow::Status::OK()) { | ||
LOG(WARNING) << "Type conversion failed for " << node_name; | ||
return type_status; | ||
} | ||
int32 slot_number = -1; | ||
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 8, | ||
&slot_number)) { | ||
LOG(ERROR) << "Failed to parse slot number from " << node_name | ||
<< " +8= " << node_name.c_str() + 8; | ||
if (!tensorflow::strings::safe_strto32( | ||
node_name.c_str() + strlen(kInputPHName), &slot_number)) { | ||
return tensorflow::errors::InvalidArgument( | ||
"Failed to parse slot number from ", node_name); | ||
} | ||
nvinfer1::DataType dtype; | ||
auto shape = input_shapes.at(slot_number); | ||
if (shape.dims() > 8) { | ||
LOG(ERROR) << "Tensor rank is greater than 8 for " << node_name | ||
<< " at input slot " << slot_number; | ||
return tensorflow::errors::OutOfRange( | ||
"Input tensor rank is greater than 8"); | ||
auto status = ValidateInputProperties( | ||
shape, node_def.attr().at("dtype").type(), &dtype); | ||
if (!status.ok()) { | ||
const string error_message = | ||
StrCat("Validation failed for ", node_name, " and input slot ", | ||
slot_number, ": ", status.error_message()); | ||
LOG(WARNING) << error_message; | ||
return Status(status.code(), error_message); | ||
} | ||
if (VLOG_IS_ON(1)) { | ||
string dim_str("dims="); | ||
|
@@ -2689,10 +2739,10 @@ tensorflow::Status ConvertGraphDefToEngine( | |
} else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) && | ||
(node_def.op() == "Identity")) { | ||
int32 slot_number = -1; | ||
if (!tensorflow::strings::safe_strto32(node_name.c_str() + 9, | ||
&slot_number)) { | ||
LOG(ERROR) << "Failed to parse slot number from " << node_name | ||
<< " +9=" << node_name.c_str() + 9; | ||
if (!tensorflow::strings::safe_strto32( | ||
node_name.c_str() + strlen(kOutputPHName), &slot_number)) { | ||
return tensorflow::errors::InvalidArgument( | ||
"Failed to parse slot number from ", node_name); | ||
} | ||
if (output_tensors.size() <= slot_number) { | ||
output_tensors.resize(slot_number + 1); | ||
|
@@ -2751,38 +2801,20 @@ tensorflow::Status ConvertSegmentToGraphDef( | |
"Cannot find node with id ", connection.outside_id, " in the graph."); | ||
} | ||
// Updates the shape and data types of input/output connections. | ||
tensorflow::DataType input_type = tensorflow::DT_FLOAT; | ||
tensorflow::DataType dtype; | ||
tensorflow::PartialTensorShape partial_shape; | ||
if (connection.is_input_edge) { | ||
if (graph_properties.HasOutputProperties(connection.outside_node_name)) { | ||
auto output_params = | ||
graph_properties.GetOutputProperties(connection.outside_node_name); | ||
auto out_shape = output_params.at(connection.outside_port); | ||
input_type = out_shape.dtype(); | ||
std::vector<tensorflow::int64> dims; | ||
partial_shape = out_shape.shape(); | ||
connection.outside_shape = partial_shape; | ||
} else { | ||
VLOG(0) << "Unknown output shape" << outside_node->name(); | ||
input_type = graph->FindNodeId(connection.outside_id) | ||
->output_type(connection.outside_port); | ||
} | ||
connection.connection_type = input_type; | ||
|
||
} else { // output edge | ||
if (graph_properties.HasInputProperties(connection.outside_node_name)) { | ||
auto input_params = | ||
graph_properties.GetInputProperties(connection.outside_node_name); | ||
auto in_shape = input_params.at(connection.outside_port); | ||
input_type = in_shape.dtype(); | ||
partial_shape = in_shape.shape(); | ||
connection.inside_shape = partial_shape; | ||
} else { | ||
input_type = graph->FindNodeId(connection.inside_id) | ||
->output_type(connection.outside_port); | ||
} | ||
connection.connection_type = input_type; | ||
GetInputProperties(graph_properties, | ||
graph->FindNodeId(connection.outside_id), | ||
connection.outside_port, &partial_shape, &dtype); | ||
|
||
} else { | ||
GetOutputProperties(graph_properties, | ||
graph->FindNodeId(connection.outside_id), | ||
connection.outside_port, &partial_shape, &dtype); | ||
} | ||
connection.outside_shape = partial_shape; | ||
connection.connection_type = dtype; | ||
|
||
// Add dummy input/output nodes to the segment graphdef. | ||
if (connection.is_input_edge) { | ||
|
@@ -2798,7 +2830,7 @@ tensorflow::Status ConvertSegmentToGraphDef( | |
auto seg_node = segment_def->add_node(); | ||
tensorflow::NodeDefBuilder builder(node_name, "Placeholder"); | ||
auto status = builder.Attr("shape", partial_shape) | ||
.Attr("dtype", input_type) | ||
.Attr("dtype", dtype) | ||
.Finalize(seg_node); | ||
VLOG(1) << "Constructing input " << node_name << " for the edge " | ||
<< connection.outside_node_name << ":" << connection.outside_port | ||
|
@@ -2816,7 +2848,7 @@ tensorflow::Status ConvertSegmentToGraphDef( | |
marker_nodes.insert(node_name); | ||
auto seg_node = segment_def->add_node(); | ||
tensorflow::NodeDefBuilder builder(node_name, "Identity"); | ||
auto status = builder.Input(connection.inside_node_name, 0, input_type) | ||
auto status = builder.Input(connection.inside_node_name, 0, dtype) | ||
.Finalize(seg_node); | ||
VLOG(1) << "Constructing output " << node_name << " for the edge " | ||
<< connection.inside_node_name << ":" << connection.inside_port | ||
|
@@ -2854,6 +2886,38 @@ tensorflow::Status ConvertSegmentToGraphDef( | |
return tensorflow::Status::OK(); | ||
} | ||
|
||
bool InputEdgeValidator::operator()(const tensorflow::Edge* in_edge) const { | ||
if (in_edge->IsControlEdge()) return true; | ||
PartialTensorShape shape; | ||
tensorflow::DataType dtype; | ||
GetInputProperties(graph_properties_, in_edge->src(), in_edge->src_output(), | ||
&shape, &dtype); | ||
nvinfer1::DataType trt_dtype; | ||
Status status = ValidateInputProperties(shape, dtype, &trt_dtype); | ||
if (!status.ok()) { | ||
VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() | ||
<< ": " << status; | ||
return false; | ||
} | ||
if (shape.dims() < 3 && in_edge->src()->type_string() != "Const") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shape.dims() < 3? TRT 3.0 requires input to be of rank 4 (batch + 3 dimension). This requirement has been removed in TRT 4.0. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As discussed offline, I'll leave this as is since there are more places to change besides this one regarding the 3 dims requirement. And we can issue another PR to fix the problem. |
||
VLOG(2) << "--> Need to remove input node " << in_edge->dst()->name() | ||
<< " which has an input at port " << in_edge->dst_input() | ||
<< " with #dim<3 and is not a const: " << shape; | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
bool OutputEdgeValidator::operator()(const tensorflow::Edge* out_edge) const { | ||
if (out_edge->IsControlEdge()) return true; | ||
if (out_edge->src()->type_string() == "Const") { | ||
VLOG(2) << "--> Need to remove output node " << out_edge->src()->name() | ||
<< " which is a Const."; | ||
return false; | ||
} | ||
return true; | ||
} | ||
|
||
} // namespace convert | ||
} // namespace tensorrt | ||
} // namespace tensorflow | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TRT supports tensor with 8 dimension EXCLUDING batch dimension. So we should be able to support shape.dims() <= 9.
It is possible that the old code got the check condition wrong.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing out. I'll fix this before this is merged.