Skip to content

Commit

Permalink
Merge pull request #54096 from tensorflow/cherrypick-ebc1a2ffe5a7573d…
Browse files Browse the repository at this point in the history
…905e99bd0ee3568ee07c12c1-on-r2.6

Make `IsSimplifiableReshape` return `Status` instead of `bool`.
  • Loading branch information
mihaimaruseac committed Jan 26, 2022
2 parents e6c146a + b28babd commit 263a6bd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
28 changes: 19 additions & 9 deletions tensorflow/core/grappler/optimizers/constant_folding.cc
Expand Up @@ -1681,10 +1681,10 @@ Status ConstantFolding::FoldGraph(
return Status::OK();
}

bool ConstantFolding::IsSimplifiableReshape(
Status ConstantFolding::IsSimplifiableReshape(
const NodeDef& node, const GraphProperties& properties) const {
if (!IsReshape(node)) {
return false;
return errors::Internal("Node ", node.name(), " is not a Reshape node");
}
if (2 > node.input_size()) {
return errors::Internal("Node ", node.name(),
Expand All @@ -1693,7 +1693,9 @@ bool ConstantFolding::IsSimplifiableReshape(
}
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
if (!IsReallyConstant(*new_shape)) {
return false;
return errors::Internal("Node ", node.name(), " has shape ",
new_shape->DebugString(),
" which is not a constant");
}
TensorVector outputs;
auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
Expand All @@ -1704,7 +1706,7 @@ bool ConstantFolding::IsSimplifiableReshape(

Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
if (!s.ok()) {
return false;
return errors::Internal("Could not evaluate node ", node.name());
}
if (outputs.size() != 1) {
return errors::Internal("Node ", node.name(),
Expand All @@ -1715,15 +1717,18 @@ bool ConstantFolding::IsSimplifiableReshape(
const std::vector<OpInfo::TensorProperties>& props =
properties.GetInputProperties(node.name());
if (props.empty()) {
return false;
return errors::Internal("Node ", node.name(), " has no properties");
}
const OpInfo::TensorProperties& prop = props[0];
if (prop.dtype() == DT_INVALID) {
return false;
return errors::Internal("Node ", node.name(), " has property ",
prop.DebugString(), " with invalid dtype");
}
const PartialTensorShape shape(prop.shape());
if (!shape.IsFullyDefined()) {
return false;
return errors::Internal("Node ", node.name(), " has property ",
prop.DebugString(), " with shape ",
shape.DebugString(), " which is not fully defined");
}

PartialTensorShape new_dims;
Expand All @@ -1745,7 +1750,12 @@ bool ConstantFolding::IsSimplifiableReshape(
if (!s.ok()) return s;
}

return shape.IsCompatibleWith(new_dims);
if (!shape.IsCompatibleWith(new_dims)) {
return errors::Internal("Expected shape ", shape.DebugString(),
"to be compatible with ", new_dims.DebugString());
}

return Status::OK();
}

#define IS_VALUE_CASE(DTYPE, VALUE) \
Expand Down Expand Up @@ -2931,7 +2941,7 @@ bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
bool use_shape_info, NodeDef* node) {
if (!use_shape_info || node->attr().count("T") == 0 ||
!IsSimplifiableReshape(*node, properties)) {
!IsSimplifiableReshape(*node, properties).ok()) {
return false;
}
DataType output_type = node->attr().at("T").type();
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/grappler/optimizers/constant_folding.h
Expand Up @@ -129,8 +129,8 @@ class ConstantFolding : public GraphOptimizer {
Status FoldGraph(const GraphProperties& properties, GraphDef* output,
absl::flat_hash_set<string>* nodes_to_not_simplify);

bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
Status IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph,
GraphProperties* properties,
absl::flat_hash_set<string>* nodes_to_not_simplify);
Expand Down

0 comments on commit 263a6bd

Please sign in to comment.