Skip to content
Permalink
Browse files Browse the repository at this point in the history
Make IsSimplifiableReshape return Status instead of bool.
This is to allow remove `CHECK`-fails in subsequent commits.

PiperOrigin-RevId: 409160987
Change-Id: I3f050218a3832271395c4372a0b8ea05f1c03d80
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Nov 11, 2021
1 parent 71399f9 commit ebc1a2f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
28 changes: 19 additions & 9 deletions tensorflow/core/grappler/optimizers/constant_folding.cc
Expand Up @@ -1684,15 +1684,17 @@ Status ConstantFolding::FoldGraph(
return Status::OK();
}

bool ConstantFolding::IsSimplifiableReshape(
Status ConstantFolding::IsSimplifiableReshape(
const NodeDef& node, const GraphProperties& properties) const {
if (!IsReshape(node)) {
return false;
return errors::Internal("Node ", node.name(), " is not a Reshape node");
}
CHECK_LE(2, node.input_size());
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
if (!IsReallyConstant(*new_shape)) {
return false;
return errors::Internal("Node ", node.name(), " has shape ",
new_shape->DebugString(),
" which is not a constant");
}
TensorVector outputs;
auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
Expand All @@ -1703,22 +1705,25 @@ bool ConstantFolding::IsSimplifiableReshape(

Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
if (!s.ok()) {
return false;
return errors::Internal("Could not evaluate node ", node.name());
}
CHECK_EQ(1, outputs.size());

const std::vector<OpInfo::TensorProperties>& props =
properties.GetInputProperties(node.name());
if (props.empty()) {
return false;
return errors::Internal("Node ", node.name(), " has no properties");
}
const OpInfo::TensorProperties& prop = props[0];
if (prop.dtype() == DT_INVALID) {
return false;
return errors::Internal("Node ", node.name(), " has property ",
prop.DebugString(), " with invalid dtype");
}
const PartialTensorShape shape(prop.shape());
if (!shape.IsFullyDefined()) {
return false;
return errors::Internal("Node ", node.name(), " has property ",
prop.DebugString(), " with shape ",
shape.DebugString(), " which is not fully defined");
}

PartialTensorShape new_dims;
Expand All @@ -1738,7 +1743,12 @@ bool ConstantFolding::IsSimplifiableReshape(
TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
}

return shape.IsCompatibleWith(new_dims);
if (!shape.IsCompatibleWith(new_dims)) {
return errors::Internal("Expected shape ", shape.DebugString(),
"to be compatible with ", new_dims.DebugString());
}

return Status::OK();
}

#define IS_VALUE_CASE(DTYPE, VALUE) \
Expand Down Expand Up @@ -2925,7 +2935,7 @@ bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
bool use_shape_info, NodeDef* node) {
if (!use_shape_info || node->attr().count("T") == 0 ||
!IsSimplifiableReshape(*node, properties)) {
!IsSimplifiableReshape(*node, properties).ok()) {
return false;
}
DataType output_type = node->attr().at("T").type();
Expand Down
4 changes: 2 additions & 2 deletions tensorflow/core/grappler/optimizers/constant_folding.h
Expand Up @@ -129,8 +129,8 @@ class ConstantFolding : public GraphOptimizer {
Status FoldGraph(const GraphProperties& properties, GraphDef* output,
absl::flat_hash_set<string>* nodes_to_not_simplify);

bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
Status IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
Status SimplifyGraph(GraphDef* optimized_graph, GraphProperties* properties,
absl::flat_hash_set<string>* nodes_to_not_simplify);
Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
Expand Down

0 comments on commit ebc1a2f

Please sign in to comment.