From 5e10d84f0ae98bbbd910ed1b83dbf5e80928df9b Mon Sep 17 00:00:00 2001 From: zilinzhu Date: Tue, 2 Mar 2021 16:04:17 +0800 Subject: [PATCH] add ds.shard(1, 0) to noop_elimination --- .../optimizers/data/noop_elimination.cc | 9 +- .../optimizers/data/noop_elimination_test.cc | 161 +++++++++++------- .../optimization/noop_elimination_test.py | 2 + 3 files changed, 114 insertions(+), 58 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index 8d84162d900bb2..abe599afd80e99 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -67,6 +67,13 @@ bool IsPrefetchZero(const NodeDef& prefetch_node, return IsConstNodeWithValue(*graph.GetNode(prefetch_node.input(1)), 0); } +bool IsShardOne(const NodeDef& shard_node, + const MutableGraphView& graph) { + if (shard_node.op() != "ShardDataset") return false; + // We are looking only for shard(0) nodes. + return IsConstNodeWithValue(*graph.GetNode(shard_node.input(1)), 1); +} + bool IsOutputIdentityOfInput(const FunctionDef& fdef, const string& output_arg, const string& input_arg) { if (!fdef.ret().contains(output_arg)) { @@ -131,7 +138,7 @@ bool IsMapIdentity(const NodeDef& map_node, const MutableGraphView& graph) { bool IsNoOp(const NodeDef& node, const MutableGraphView& graph) { return IsTakeAll(node, graph) || IsSkipNone(node, graph) || IsRepeatOne(node, graph) || IsPrefetchZero(node, graph) || - IsMapIdentity(node, graph); + IsShardOne(node, graph) || IsMapIdentity(node, graph); } } // namespace diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc index 323bb1d599f36e..dff90670f71a09 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc @@ -35,19 +35,35 @@ std::vector> GetCommonAttributes() { return commonAttributes; } -NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node, - MutableGraphView *graph) { - NodeDef *node_count = graph_utils::AddScalarConstNode(count, graph); - return graph_utils::AddNode("", node_type, - {std::move(input_node), node_count->name()}, +NodeDef *MakeNode(StringPiece node_type, std::vector params, + string input_node, MutableGraphView *graph) { + std::vector node_params; + for (int param : params) { + node_params.push_back( + graph_utils::AddScalarConstNode(param, graph)); + } + std::vector inputs = {input_node}; + for (int i = 0; i < node_params.size(); i++) { + inputs.push_back(node_params[i]->name()); + } + return graph_utils::AddNode("", node_type, inputs, GetCommonAttributes(), graph); } -NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node, - MutableGraphView *graph) { - NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph); - return graph_utils::AddNode("", node_type, - {std::move(input_node), node_count->name()}, +NodeDef *MakeNonConstNode( + StringPiece node_type, std::vector param_dtypes, + string input_node, MutableGraphView *graph) { + std::vector node_params; + for (DataType dtype : param_dtypes) { + node_params.push_back( + graph_utils::AddScalarPlaceholder(dtype, graph)); + } + std::vector inputs = {input_node}; + for (int i = 0; i < node_params.size(); i++) { + inputs.push_back(node_params[i]->name()); + } + + return graph_utils::AddNode("", node_type, inputs, GetCommonAttributes(), graph); } @@ -72,7 +88,7 @@ NodeDef *MakeRangeNode(MutableGraphView *graph) { } struct NoOpLastEliminationTest - : ::testing::TestWithParam> {}; + : ::testing::TestWithParam, bool>> {}; // This test checks whether the no-op elimination correctly handles // transformations at the end of the pipeline. @@ -81,13 +97,13 @@ TEST_P(NoOpLastEliminationTest, EliminateLastNoOpNode) { MutableGraphView graph(&item.graph); const string &node_type = std::get<0>(GetParam()); - const int node_count = std::get<1>(GetParam()); + const std::vector node_params = std::get<1>(GetParam()); const bool should_keep_node = std::get<2>(GetParam()); NodeDef *range_node = MakeRangeNode(&graph); NodeDef *node = - MakeUnaryNode(node_type, node_count, range_node->name(), &graph); + MakeNode(node_type, node_params, range_node->name(), &graph); NoOpElimination optimizer; GraphDef output; @@ -99,20 +115,23 @@ TEST_P(NoOpLastEliminationTest, EliminateLastNoOpNode) { INSTANTIATE_TEST_CASE_P( BasicRemovalTest, NoOpLastEliminationTest, - ::testing::Values(std::make_tuple("TakeDataset", -3, false), - std::make_tuple("TakeDataset", -1, false), - std::make_tuple("TakeDataset", 0, true), - std::make_tuple("TakeDataset", 3, true), - std::make_tuple("SkipDataset", -1, true), - std::make_tuple("SkipDataset", 0, false), - std::make_tuple("SkipDataset", 3, true), - std::make_tuple("PrefetchDataset", 0, false), - std::make_tuple("PrefetchDataset", 1, true), - std::make_tuple("RepeatDataset", 1, false), - std::make_tuple("RepeatDataset", 2, true))); + ::testing::Values( + std::make_tuple("TakeDataset", std::vector({-3}), false), + std::make_tuple("TakeDataset", std::vector({-1}), false), + std::make_tuple("TakeDataset", std::vector({0}), true), + std::make_tuple("TakeDataset", std::vector({3}), true), + std::make_tuple("SkipDataset", std::vector({-1}), true), + std::make_tuple("SkipDataset", std::vector({0}), false), + std::make_tuple("SkipDataset", std::vector({3}), true), + std::make_tuple("PrefetchDataset", std::vector({0}), false), + std::make_tuple("PrefetchDataset", std::vector({1}), true), + std::make_tuple("RepeatDataset", std::vector({1}), false), + std::make_tuple("RepeatDataset", std::vector({2}), true), + std::make_tuple("ShardDataset", std::vector({1, 0}), false), + std::make_tuple("ShardDataset", std::vector({2, 0}), true))); struct NoOpMiddleEliminationTest - : ::testing::TestWithParam> {}; + : ::testing::TestWithParam, bool>> {}; // This test checks whether the no-op elimination correctly handles // transformations int the middle of the pipeline. @@ -121,13 +140,13 @@ TEST_P(NoOpMiddleEliminationTest, EliminateMiddleNoOpNode) { MutableGraphView graph(&item.graph); const string &node_type = std::get<0>(GetParam()); - const int node_count = std::get<1>(GetParam()); + const std::vector node_params = std::get<1>(GetParam()); const bool should_keep_node = std::get<2>(GetParam()); NodeDef *range_node = MakeRangeNode(&graph); NodeDef *node = - MakeUnaryNode(node_type, node_count, range_node->name(), &graph); + MakeNode(node_type, node_params, range_node->name(), &graph); NodeDef *cache_node = MakeCacheNode(node->name(), &graph); NoOpElimination optimizer; @@ -149,19 +168,23 @@ TEST_P(NoOpMiddleEliminationTest, EliminateMiddleNoOpNode) { INSTANTIATE_TEST_CASE_P( BasicRemovalTest, NoOpMiddleEliminationTest, - ::testing::Values(std::make_tuple("TakeDataset", -1, false), - std::make_tuple("TakeDataset", -3, false), - std::make_tuple("TakeDataset", 0, true), - std::make_tuple("TakeDataset", 3, true), - std::make_tuple("SkipDataset", -1, true), - std::make_tuple("SkipDataset", 0, false), - std::make_tuple("SkipDataset", 3, true), - std::make_tuple("PrefetchDataset", 0, false), - std::make_tuple("PrefetchDataset", 1, true), - std::make_tuple("RepeatDataset", 1, false), - std::make_tuple("RepeatDataset", 2, true))); - -using NodesTypes = std::tuple, std::pair>; + ::testing::Values( + std::make_tuple("TakeDataset", std::vector({-1}), false), + std::make_tuple("TakeDataset", std::vector({-3}), false), + std::make_tuple("TakeDataset", std::vector({0}), true), + std::make_tuple("TakeDataset", std::vector({3}), true), + std::make_tuple("SkipDataset", std::vector({-1}), true), + std::make_tuple("SkipDataset", std::vector({0}), false), + std::make_tuple("SkipDataset", std::vector({3}), true), + std::make_tuple("PrefetchDataset", std::vector({0}), false), + std::make_tuple("PrefetchDataset", std::vector({1}), true), + std::make_tuple("RepeatDataset", std::vector({1}), false), + std::make_tuple("RepeatDataset", std::vector({2}), true), + std::make_tuple("ShardDataset", std::vector({1, 0}), false), + std::make_tuple("ShardDataset", std::vector({2, 0}), true))); + +using NodesTypes = std::tuple>, + std::pair>>; struct NoOpMultipleEliminationTest : ::testing::TestWithParam {}; // This test checks whether the no-op elimination correctly removes @@ -172,7 +195,7 @@ TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) { static_assert(std::tuple_size::value == 2, "Make sure to include everything in the test"); - const std::vector> noop_nodes = { + const std::vector>> noop_nodes = { std::get<0>(GetParam()), std::get<1>(GetParam())}; NodeDef *range_node = MakeRangeNode(&graph); @@ -182,8 +205,8 @@ TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) { nodes_to_remove.reserve(noop_nodes.size()); for (const auto &noop_node : noop_nodes) { - NodeDef *node = MakeUnaryNode(noop_node.first, noop_node.second, - previous->name(), &graph); + NodeDef *node = MakeNode(noop_node.first, noop_node.second, + previous->name(), &graph); nodes_to_remove.push_back(node->name()); previous = node; } @@ -207,21 +230,28 @@ TEST_P(NoOpMultipleEliminationTest, EliminateMultipleNoOpNode) { EXPECT_EQ(cache_node_out.input(0), range_node->name()); } -const auto *const kTakeNode = new std::pair{"TakeDataset", -1}; -const auto *const kSkipNode = new std::pair{"SkipDataset", 0}; -const auto *const kRepeatNode = new std::pair{"RepeatDataset", 1}; +const auto *const kTakeNode = + new std::pair>{"TakeDataset", {-1}}; +const auto *const kSkipNode = + new std::pair>{"SkipDataset", {0}}; +const auto *const kRepeatNode = + new std::pair>{"RepeatDataset", {1}}; const auto *const kPrefetchNode = - new std::pair{"PrefetchDataset", 0}; + new std::pair>{"PrefetchDataset", {0}}; +const auto *const kShardNode = + new std::pair>{"ShardDataset", {1, 0}}; INSTANTIATE_TEST_CASE_P( BasicRemovalTest, NoOpMultipleEliminationTest, ::testing::Combine(::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode, - *kPrefetchNode), + *kPrefetchNode, *kShardNode), ::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode, - *kPrefetchNode))); + *kPrefetchNode, *kShardNode))); struct NoOpPlaceholdersTest - : ::testing::TestWithParam> {}; + : ::testing::TestWithParam< + std::tuple>, + std::pair>>> {}; TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) { GrapplerItem item; @@ -229,15 +259,16 @@ TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) { static_assert(std::tuple_size::value == 2, "Make sure to include everything in the test"); - const std::vector noop_nodes = {std::get<0>(GetParam()), - std::get<1>(GetParam())}; + const std::vector>> noop_nodes = + {std::get<0>(GetParam()), std::get<1>(GetParam())}; NodeDef *range_node = MakeRangeNode(&graph); std::vector nodes_to_keep; nodes_to_keep.reserve(noop_nodes.size()); NodeDef *previous = range_node; for (const auto &noop_node : noop_nodes) { - NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph); + NodeDef *node = MakeNonConstNode(noop_node.first, noop_node.second, + previous->name(), &graph); nodes_to_keep.push_back(node->name()); previous = node; } @@ -249,12 +280,28 @@ TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) { EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output)); } +const auto *const kNonConstTakeNode = + new std::pair>{"TakeDataset", {DT_INT32}}; +const auto *const kNonConstSkipNode = + new std::pair>{"SkipDataset", {DT_INT32}}; +const auto *const kNonConstRepeatNode = + new std::pair>{"RepeatDataset", {DT_INT32}}; +const auto *const kNonConstPrefetchNode = + new std::pair>{"PrefetchDataset", + {DT_INT32}}; +const auto *const kNonConstShardNode = + new std::pair>{"ShardDataset", + {DT_INT32, DT_INT32}}; + INSTANTIATE_TEST_CASE_P( DoNotRemovePlaceholders, NoOpPlaceholdersTest, - ::testing::Combine(::testing::Values("TakeDataset", "SkipDataset", - "RepeatDataset", "PrefetchDataset"), - ::testing::Values("TakeDataset", "SkipDataset", - "RepeatDataset", "PrefetchDataset"))); + ::testing::Combine( + ::testing::Values(*kNonConstTakeNode, *kNonConstSkipNode, + *kNonConstRepeatNode, *kNonConstPrefetchNode, + *kNonConstShardNode), + ::testing::Values(*kNonConstTakeNode, *kNonConstSkipNode, + *kNonConstRepeatNode, *kNonConstPrefetchNode, + *kNonConstShardNode))); } // namespace } // namespace grappler diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py index e8fdf5f2e24cd4..c0609103f15e37 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py @@ -73,6 +73,8 @@ def apply_map_with_multiple_components(ds): ("PMapNonIdentity", lambda ds: ds.map(lambda x: x * 2, num_parallel_calls=2), parallel_map_name), + ("Shard1", lambda ds: ds.shard(1, 0), None), + ("ShardN", lambda ds: ds.shard(2, 0), "Shard"), ] def reduce_fn(result, case):