From b28babd26787127a21d3d8214f9318ce59a2a2c4 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Thu, 11 Nov 2021 08:58:30 -0800 Subject: [PATCH] Make `IsSimplifiableReshape` return `Status` instead of `bool`. This is to allow remove `CHECK`-fails in subsequent commits. PiperOrigin-RevId: 409160987 Change-Id: I3f050218a3832271395c4372a0b8ea05f1c03d80 --- .../grappler/optimizers/constant_folding.cc | 28 +++++++++++++------ .../grappler/optimizers/constant_folding.h | 4 +-- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 840e138d69f113..24d30eb998bf23 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1681,10 +1681,10 @@ Status ConstantFolding::FoldGraph( return Status::OK(); } -bool ConstantFolding::IsSimplifiableReshape( +Status ConstantFolding::IsSimplifiableReshape( const NodeDef& node, const GraphProperties& properties) const { if (!IsReshape(node)) { - return false; + return errors::Internal("Node ", node.name(), " is not a Reshape node"); } if (2 > node.input_size()) { return errors::Internal("Node ", node.name(), @@ -1693,7 +1693,9 @@ bool ConstantFolding::IsSimplifiableReshape( } const NodeDef* new_shape = node_map_->GetNode(node.input(1)); if (!IsReallyConstant(*new_shape)) { - return false; + return errors::Internal("Node ", node.name(), " has shape ", + new_shape->DebugString(), + " which is not a constant"); } TensorVector outputs; auto outputs_cleanup = gtl::MakeCleanup([&outputs] { @@ -1704,7 +1706,7 @@ bool ConstantFolding::IsSimplifiableReshape( Status s = EvaluateNode(*new_shape, TensorVector(), &outputs); if (!s.ok()) { - return false; + return errors::Internal("Could not evaluate node ", node.name()); } if (outputs.size() != 1) { return errors::Internal("Node ", node.name(), @@ -1715,15 +1717,18 @@ bool ConstantFolding::IsSimplifiableReshape( const std::vector& props = properties.GetInputProperties(node.name()); if (props.empty()) { - return false; + return errors::Internal("Node ", node.name(), " has no properties"); } const OpInfo::TensorProperties& prop = props[0]; if (prop.dtype() == DT_INVALID) { - return false; + return errors::Internal("Node ", node.name(), " has property ", + prop.DebugString(), " with invalid dtype"); } const PartialTensorShape shape(prop.shape()); if (!shape.IsFullyDefined()) { - return false; + return errors::Internal("Node ", node.name(), " has property ", + prop.DebugString(), " with shape ", + shape.DebugString(), " which is not fully defined"); } PartialTensorShape new_dims; @@ -1745,7 +1750,12 @@ bool ConstantFolding::IsSimplifiableReshape( if (!s.ok()) return s; } - return shape.IsCompatibleWith(new_dims); + if (!shape.IsCompatibleWith(new_dims)) { + return errors::Internal("Expected shape ", shape.DebugString(), + "to be compatible with ", new_dims.DebugString()); + } + + return Status::OK(); } #define IS_VALUE_CASE(DTYPE, VALUE) \ @@ -2931,7 +2941,7 @@ bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph, bool ConstantFolding::SimplifyReshape(const GraphProperties& properties, bool use_shape_info, NodeDef* node) { if (!use_shape_info || node->attr().count("T") == 0 || - !IsSimplifiableReshape(*node, properties)) { + !IsSimplifiableReshape(*node, properties).ok()) { return false; } DataType output_type = node->attr().at("T").type(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 29b9f3b270af5f..538594e540097a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -129,8 +129,8 @@ class ConstantFolding : public GraphOptimizer { Status FoldGraph(const GraphProperties& properties, GraphDef* output, absl::flat_hash_set* nodes_to_not_simplify); - bool IsSimplifiableReshape(const NodeDef& node, - const GraphProperties& properties) const; + Status IsSimplifiableReshape(const NodeDef& node, + const GraphProperties& properties) const; Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties, absl::flat_hash_set* nodes_to_not_simplify);