Skip to content

Commit

Permalink
Fix RemoveUnusedNodes generating invalid graphs for PlaceholderWithDefault inputs

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 199776409
  • Loading branch information
tensorflower-gardener committed Jun 8, 2018
1 parent 16c1d25 commit 1c241ba
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 46 deletions.
26 changes: 26 additions & 0 deletions tensorflow/tools/graph_transforms/fold_constants_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,31 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
return Status::OK();
}

// Rewrites the graph's requested input nodes so they are plain Placeholder
// ops: a PlaceholderWithDefault is demoted to Placeholder (its default-value
// input edge is dropped), while any other op type is rejected. This keeps
// pruning passes from leaving dangling references to the removed defaults.
//
// Returns InvalidArgument if a named input is neither a Placeholder nor a
// PlaceholderWithDefault; otherwise OK. Mutates *graph_def in place.
Status RewriteInputsAsPlaceholders(const TransformFuncContext& context,
                                   GraphDef* graph_def) {
  // Collect the node names of every requested graph input (tensor suffixes
  // like ":0" are stripped by ParseTensorName).
  std::unordered_set<string> requested_inputs;
  for (const string& name : context.input_names) {
    requested_inputs.insert(ParseTensorName(name).first.ToString());
  }

  for (NodeDef& node : *graph_def->mutable_node()) {
    if (requested_inputs.count(node.name()) == 0) {
      continue;  // Not one of the graph's inputs; leave untouched.
    }
    if (node.op() == "Placeholder") {
      continue;  // Already in the desired form.
    }
    if (node.op() != "PlaceholderWithDefault") {
      return errors::InvalidArgument(
          "Input '", node.name(),
          "' was expected to be a Placeholder or PlaceholderWithDefault op, "
          "but was ",
          node.op());
    }
    // Demote to a plain Placeholder and drop the default-value input edge.
    node.set_op("Placeholder");
    node.clear_input();
  }
  return Status::OK();
}

Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
Expand Down Expand Up @@ -165,6 +190,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
input_graph_def,
[&](const NodeDef& node) { return used_nodes.count(node.name()) > 0; },
output_graph_def);
TF_RETURN_IF_ERROR(RewriteInputsAsPlaceholders(context, output_graph_def));

return Status::OK();
}
Expand Down
46 changes: 0 additions & 46 deletions tensorflow/tools/graph_transforms/fold_constants_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,48 +330,6 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("unused"));
}

void TestRemoveUnusedNodesMultipleOutputs() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
auto root = tensorflow::Scope::NewRootScope();

// a b
// \ /
// shape_n
// \ /
// c
auto a = Placeholder(root.WithOpName("a"), DT_FLOAT);
auto b = Placeholder(root.WithOpName("b"), DT_FLOAT);
auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)});
auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]);

GraphDef graph_def;
TF_ASSERT_OK(root.ToGraphDef(&graph_def));
GraphDef result_graph_def;
TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def));

// Only one output of shape_n node is fed input. Hence the graph search
// should propagate to inputs of shape_n. Nothing to remove here.
std::map<string, const NodeDef*> node_map;
graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
EXPECT_EQ(1, node_map.count("a"));
EXPECT_EQ(1, node_map.count("b"));
EXPECT_EQ(1, node_map.count("c"));

result_graph_def.Clear();
TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}},
&result_graph_def));

// Both outputs of shape_n node are fed inputs. shape_n does not function
// and inputs to shape_n should be removed.
node_map.clear();
graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
EXPECT_EQ(0, node_map.count("a"));
EXPECT_EQ(0, node_map.count("b"));
EXPECT_EQ(1, node_map.count("c"));
}

void TestMaxConstantSizeInBytes() {
auto root = tensorflow::Scope::NewRootScope();

Expand Down Expand Up @@ -431,10 +389,6 @@ TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {

TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }

// Exercises pruning around multi-output nodes via the fixture helper.
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) { TestRemoveUnusedNodesMultipleOutputs(); }

// Exercises the constant-size limit via the fixture helper.
TEST_F(ConstantFoldingTest, TestMaxConstantSizeInBytes) { TestMaxConstantSizeInBytes(); }
Expand Down

0 comments on commit 1c241ba

Please sign in to comment.