From ca4e75a6f84ca51b9cef0a4aec7816d38512a8e2 Mon Sep 17 00:00:00 2001 From: Daniel Rasmussen Date: Tue, 30 Oct 2018 18:10:53 -0300 Subject: [PATCH] Fix concat optimization infinite loop --- .../grappler/optimizers/constant_folding.cc | 10 +++--- .../optimizers/constant_folding_test.cc | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 3882e3b3a9a0fa..803cc80cd44387 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -3022,11 +3022,11 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph, for (auto interval : constant_input_runs) { // Push the constant inputs in the interval to a child node than can be // constant folded. - const string new_node_name = OptimizedNodeName( - *node, strings::StrCat("_partial_split_", interval.first)); - if (node_map_->NodeExists(new_node_name)) { - break; - } + string new_node_name = OptimizedNodeName(*node, "_partial_split"); + do { + new_node_name += strings::StrCat("_", interval.first); + } while (node_map_->NodeExists(new_node_name)); + NodeDef* added_node = optimized_graph->add_node(); *added_node = *node; added_node->set_name(new_node_name); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 192f48272f9ed0..76d2555c03752f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -2154,6 +2154,40 @@ TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) { CompareGraphs(want, got); } +TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) { + Scope scope = Scope::NewRootScope(); + Output c1 = ops::Const(scope.WithOpName("c1"), 1.0f, {2, 2}); + Output c2 = ops::Const(scope.WithOpName("c2"), 2.0f, {2, 2}); + Output c3 = ops::Const(scope.WithOpName("c3"), 3.0f, {2, 2}); + Output c4 = ops::Const(scope.WithOpName("c4"), 4.0f, {2, 2}); + Output ph = ops::Placeholder(scope.WithOpName("ph"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({2, 2}))); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + + ops::Concat concat1(scope.WithOpName("concat1"), {c1, c2, ph}, axis); + ops::Concat concat2(scope.WithOpName("concat2"), {c3, c4, Output(concat1)}, + axis); + + GrapplerItem item; + item.fetch = {"concat2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("ConstantFolding/concat2_partial_split_0_0", "Const", {}, {}, &want); + AddNode("axis", "Const", {}, {}, &want); + AddNode("ph", "Placeholder", {}, {}, &want); + AddNode("concat2", "ConcatV2", + {"ConstantFolding/concat2_partial_split_0_0", "ph", "axis"}, {}, + &want); + + CompareGraphs(want, got); +} + TEST_F(ConstantFoldingTest, PaddingWithZeroSize) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope();