diff --git a/ngraph_bridge/ngraph_assign_clusters.cc b/ngraph_bridge/ngraph_assign_clusters.cc index 279a7a1c3..5cd9bd4e7 100644 --- a/ngraph_bridge/ngraph_assign_clusters.cc +++ b/ngraph_bridge/ngraph_assign_clusters.cc @@ -817,6 +817,10 @@ Status GetNodeCluster(const Node* node, int* cluster) { return s; } +void ResetAssignClusters(Graph* graph) { + ClearAttribute(graph, {"_ngraph_cluster"}); +} + } // namespace ngraph_bridge } // namespace tensorflow diff --git a/ngraph_bridge/ngraph_assign_clusters.h b/ngraph_bridge/ngraph_assign_clusters.h index 33854aa65..1fbba02df 100644 --- a/ngraph_bridge/ngraph_assign_clusters.h +++ b/ngraph_bridge/ngraph_assign_clusters.h @@ -25,6 +25,8 @@ namespace tensorflow { namespace ngraph_bridge { Status AssignClusters(Graph* graph); +// reset the effect of AssignClusters +void ResetAssignClusters(Graph* graph); Status GetNodeCluster(const Node* node, int* cluster); } // namespace ngraph_bridge diff --git a/ngraph_bridge/ngraph_mark_for_clustering.cc b/ngraph_bridge/ngraph_mark_for_clustering.cc index 79f0a078d..37ef9ce57 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.cc +++ b/ngraph_bridge/ngraph_mark_for_clustering.cc @@ -1271,6 +1271,11 @@ void SetNodeBackend(Node* node, const string& backend_name) { node->AddAttr("_ngraph_backend", backend_name); } +void ResetMarkForClustering(Graph* graph) { + ClearAttribute(graph, {"_ngraph_marked_for_clustering", "_ngraph_backend", + "_ngraph_static_inputs"}); +} + } // namespace ngraph_bridge } // namespace tensorflow diff --git a/ngraph_bridge/ngraph_mark_for_clustering.h b/ngraph_bridge/ngraph_mark_for_clustering.h index 41a8d02a0..ea4739741 100644 --- a/ngraph_bridge/ngraph_mark_for_clustering.h +++ b/ngraph_bridge/ngraph_mark_for_clustering.h @@ -28,6 +28,8 @@ namespace ngraph_bridge { Status MarkForClustering(Graph* graph, std::set skip_these_nodes, const string& current_backend); +// remove marking, backend and static input nodes attributes +void ResetMarkForClustering(Graph* graph); Status IsSupportedByBackend( const Node* node, const ngraph::runtime::Backend* op_backend, std::map>>& diff --git a/ngraph_bridge/ngraph_utils.cc b/ngraph_bridge/ngraph_utils.cc index e3cad7be8..4c0f69416 100644 --- a/ngraph_bridge/ngraph_utils.cc +++ b/ngraph_bridge/ngraph_utils.cc @@ -600,6 +600,15 @@ bool IsProcessedByNgraphPass(Graph* g) { return false; } +void ClearAttribute(Graph* g, + const std::set& attributes_to_be_cleared) { + for (auto node : g->nodes()) { + for (const auto& attr : attributes_to_be_cleared) { + node->ClearAttr(attr); + } + } +} + } // namespace ngraph_bridge } // namespace tensorflow diff --git a/ngraph_bridge/ngraph_utils.h b/ngraph_bridge/ngraph_utils.h index 4c94accb9..84cd624de 100644 --- a/ngraph_bridge/ngraph_utils.h +++ b/ngraph_bridge/ngraph_utils.h @@ -348,6 +348,8 @@ std::string GraphFilenamePrefix(std::string, int); std::string GraphFilenamePrefix(std::string, int, int); +void ClearAttribute(Graph*, const std::set&); + void DumpGraphs(const GraphOptimizationPassOptions& options, int idx, std::string filename_prefix, std::string title); diff --git a/test/graph_rewrites/assign_clusters.cc b/test/graph_rewrites/assign_clusters.cc index 0eb010c2e..02bfb6b17 100644 --- a/test/graph_rewrites/assign_clusters.cc +++ b/test/graph_rewrites/assign_clusters.cc @@ -103,6 +103,7 @@ TEST(AssignClusters, ConstToStatic) { // Node1-->Node2 coalesced // Node1-->Node3 coalesced **actually invalid, because Node1 is now in same // cluster as Node2, and we can't contract 2 & 3. +// Also tests ResetAssignClusters TEST(AssignClusters, Cone) { Graph g(OpRegistry::Global()); @@ -147,11 +148,24 @@ TEST(AssignClusters, Cone) { ASSERT_OK(AssignClusters(&g)); - int node2_cluster, node3_cluster; + int node1_cluster, node2_cluster, node3_cluster; + ASSERT_OK(GetNodeCluster(node1, &node1_cluster)); ASSERT_OK(GetNodeCluster(node2, &node2_cluster)); ASSERT_OK(GetNodeCluster(node3, &node3_cluster)); ASSERT_NE(node2_cluster, node3_cluster); + ASSERT_EQ(node1_cluster, node2_cluster); + + ResetAssignClusters(&g); + // After the reset function the attribute should have disappeared, and using + // GetNodeCluster should return -1 + ASSERT_NOT_OK(GetNodeCluster(node1, &node1_cluster)); + ASSERT_NOT_OK(GetNodeCluster(node2, &node2_cluster)); + ASSERT_NOT_OK(GetNodeCluster(node3, &node3_cluster)); + + ASSERT_EQ(node1_cluster, -1); + ASSERT_EQ(node2_cluster, -1); + ASSERT_EQ(node3_cluster, -1); } } // namespace testing diff --git a/test/graph_rewrites/disable_ops_test.cc b/test/graph_rewrites/disable_ops_test.cc index 568dc6d59..434dc20aa 100644 --- a/test/graph_rewrites/disable_ops_test.cc +++ b/test/graph_rewrites/disable_ops_test.cc @@ -114,9 +114,7 @@ TEST(DisableOps, DisableTest) { GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked)); ASSERT_TRUE(marked); - node1->ClearAttr("_ngraph_marked_for_clustering"); - node2->ClearAttr("_ngraph_marked_for_clustering"); - node3->ClearAttr("_ngraph_marked_for_clustering"); + ResetMarkForClustering(&g); // Add is disabled config::ngraph_set_disabled_ops("Add,Mul"); @@ -132,9 +130,7 @@ TEST(DisableOps, DisableTest) { ASSERT_NOT_OK( GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked)); - node1->ClearAttr("_ngraph_marked_for_clustering"); - node2->ClearAttr("_ngraph_marked_for_clustering"); - node3->ClearAttr("_ngraph_marked_for_clustering"); + ResetMarkForClustering(&g); // Add,Add,Mul,Add should work too config::ngraph_set_disabled_ops("Add,Add,Mul,Add"); @@ -150,9 +146,7 @@ TEST(DisableOps, DisableTest) { ASSERT_NOT_OK( GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked)); - node1->ClearAttr("_ngraph_marked_for_clustering"); - node2->ClearAttr("_ngraph_marked_for_clustering"); - node3->ClearAttr("_ngraph_marked_for_clustering"); + ResetMarkForClustering(&g); // Resetting it. So Add should be accepted now config::ngraph_set_disabled_ops(""); @@ -169,9 +163,7 @@ TEST(DisableOps, DisableTest) { GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked)); ASSERT_TRUE(marked); - node1->ClearAttr("_ngraph_marked_for_clustering"); - node2->ClearAttr("_ngraph_marked_for_clustering"); - node3->ClearAttr("_ngraph_marked_for_clustering"); + ResetMarkForClustering(&g); // Invalid op name should trigger an error config::ngraph_set_disabled_ops("Add,_InvalidOp"); diff --git a/test/graph_rewrites/mark_for_clustering_test.cc b/test/graph_rewrites/mark_for_clustering_test.cc index 2ae15c294..418230cba 100644 --- a/test/graph_rewrites/mark_for_clustering_test.cc +++ b/test/graph_rewrites/mark_for_clustering_test.cc @@ -58,6 +58,12 @@ TEST(MarkForClustering, SimpleTest) { .Attr("T", DT_FLOAT) .Finalize(&g, &node3)); + Node* node4; + ASSERT_OK(NodeBuilder("node4", "Abs") + .Input(node3, 0) + .Attr("T", DT_FLOAT) + .Finalize(&g, &node4)); + // Add edges from SRC to node1 and node2 // Add edge from node3 to SINK // The graph is disconnected without these edges @@ -65,7 +71,7 @@ TEST(MarkForClustering, SimpleTest) { Node* sink = g.sink_node(); g.AddEdge(source, Graph::kControlSlot, node1, Graph::kControlSlot); g.AddEdge(source, Graph::kControlSlot, node2, Graph::kControlSlot); - g.AddEdge(node3, Graph::kControlSlot, sink, Graph::kControlSlot); + g.AddEdge(node4, Graph::kControlSlot, sink, Graph::kControlSlot); const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND"); string expected_backend{"CPU"}; @@ -75,9 +81,20 @@ TEST(MarkForClustering, SimpleTest) { ASSERT_OK(MarkForClustering(&g, {}, expected_backend)); string backend; + const set nodes_expected_to_be_marked{"node1", "node2", "node3", + "node4"}; for (auto node : g.op_nodes()) { ASSERT_OK(GetNodeBackend(node, &backend)); ASSERT_EQ(backend, expected_backend); + ASSERT_EQ(nodes_expected_to_be_marked.find(node->name()) != + nodes_expected_to_be_marked.end(), + NodeIsMarkedForClustering(node)); + } + + ResetMarkForClustering(&g); + for (auto node : g.op_nodes()) { + ASSERT_NOT_OK(GetNodeBackend(node, &backend)); + ASSERT_FALSE(NodeIsMarkedForClustering(node)); } } }