Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ngraph_bridge/ngraph_assign_clusters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,10 @@ Status GetNodeCluster(const Node* node, int* cluster) {
return s;
}

void ResetAssignClusters(Graph* graph) {
ClearAttribute(graph, {"_ngraph_cluster"});
}

} // namespace ngraph_bridge

} // namespace tensorflow
2 changes: 2 additions & 0 deletions ngraph_bridge/ngraph_assign_clusters.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace tensorflow {
namespace ngraph_bridge {

Status AssignClusters(Graph* graph);
// reset the effect of AssignClusters
void ResetAssignClusters(Graph* graph);
Status GetNodeCluster(const Node* node, int* cluster);

} // namespace ngraph_bridge
Expand Down
5 changes: 5 additions & 0 deletions ngraph_bridge/ngraph_mark_for_clustering.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1271,6 +1271,11 @@ void SetNodeBackend(Node* node, const string& backend_name) {
node->AddAttr("_ngraph_backend", backend_name);
}

void ResetMarkForClustering(Graph* graph) {
ClearAttribute(graph, {"_ngraph_marked_for_clustering", "_ngraph_backend",
"_ngraph_static_inputs"});
}

} // namespace ngraph_bridge

} // namespace tensorflow
2 changes: 2 additions & 0 deletions ngraph_bridge/ngraph_mark_for_clustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace ngraph_bridge {

Status MarkForClustering(Graph* graph, std::set<string> skip_these_nodes,
const string& current_backend);
// remove marking, backend and static input nodes attributes
void ResetMarkForClustering(Graph* graph);
Status IsSupportedByBackend(
const Node* node, const ngraph::runtime::Backend* op_backend,
std::map<std::string, std::set<std::shared_ptr<ngraph::Node>>>&
Expand Down
9 changes: 9 additions & 0 deletions ngraph_bridge/ngraph_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,15 @@ bool IsProcessedByNgraphPass(Graph* g) {
return false;
}

void ClearAttribute(Graph* g,
const std::set<string>& attributes_to_be_cleared) {
for (auto node : g->nodes()) {
for (const auto& attr : attributes_to_be_cleared) {
node->ClearAttr(attr);
}
}
}

} // namespace ngraph_bridge

} // namespace tensorflow
2 changes: 2 additions & 0 deletions ngraph_bridge/ngraph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ std::string GraphFilenamePrefix(std::string, int);

std::string GraphFilenamePrefix(std::string, int, int);

void ClearAttribute(Graph*, const std::set<string>&);

void DumpGraphs(const GraphOptimizationPassOptions& options, int idx,
std::string filename_prefix, std::string title);

Expand Down
16 changes: 15 additions & 1 deletion test/graph_rewrites/assign_clusters.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ TEST(AssignClusters, ConstToStatic) {
// Node1-->Node2 coalesced
// Node1-->Node3 coalesced **actually invalid, because Node1 is now in same
// cluster as Node2, and we can't contract 2 & 3.
// Also tests ResetAssignClusters
TEST(AssignClusters, Cone) {
Graph g(OpRegistry::Global());

Expand Down Expand Up @@ -147,11 +148,24 @@ TEST(AssignClusters, Cone) {

ASSERT_OK(AssignClusters(&g));

int node2_cluster, node3_cluster;
int node1_cluster, node2_cluster, node3_cluster;
ASSERT_OK(GetNodeCluster(node1, &node1_cluster));
ASSERT_OK(GetNodeCluster(node2, &node2_cluster));
ASSERT_OK(GetNodeCluster(node3, &node3_cluster));

ASSERT_NE(node2_cluster, node3_cluster);
ASSERT_EQ(node1_cluster, node2_cluster);

ResetAssignClusters(&g);
// After the reset function the attribute should have disappeared, and using
// GetNodeCluster should return -1
ASSERT_NOT_OK(GetNodeCluster(node1, &node1_cluster));
ASSERT_NOT_OK(GetNodeCluster(node2, &node2_cluster));
ASSERT_NOT_OK(GetNodeCluster(node3, &node3_cluster));

ASSERT_EQ(node1_cluster, -1);
ASSERT_EQ(node2_cluster, -1);
ASSERT_EQ(node3_cluster, -1);
}

} // namespace testing
Expand Down
16 changes: 4 additions & 12 deletions test/graph_rewrites/disable_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,7 @@ TEST(DisableOps, DisableTest) {
GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked));
ASSERT_TRUE(marked);

node1->ClearAttr("_ngraph_marked_for_clustering");
node2->ClearAttr("_ngraph_marked_for_clustering");
node3->ClearAttr("_ngraph_marked_for_clustering");
ResetMarkForClustering(&g);

// Add is disabled
config::ngraph_set_disabled_ops("Add,Mul");
Expand All @@ -132,9 +130,7 @@ TEST(DisableOps, DisableTest) {
ASSERT_NOT_OK(
GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked));

node1->ClearAttr("_ngraph_marked_for_clustering");
node2->ClearAttr("_ngraph_marked_for_clustering");
node3->ClearAttr("_ngraph_marked_for_clustering");
ResetMarkForClustering(&g);

// Add,Add,Mul,Add should work too
config::ngraph_set_disabled_ops("Add,Add,Mul,Add");
Expand All @@ -150,9 +146,7 @@ TEST(DisableOps, DisableTest) {
ASSERT_NOT_OK(
GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked));

node1->ClearAttr("_ngraph_marked_for_clustering");
node2->ClearAttr("_ngraph_marked_for_clustering");
node3->ClearAttr("_ngraph_marked_for_clustering");
ResetMarkForClustering(&g);

// Resetting it. So Add should be accepted now
config::ngraph_set_disabled_ops("");
Expand All @@ -169,9 +163,7 @@ TEST(DisableOps, DisableTest) {
GetNodeAttr(node3->attrs(), "_ngraph_marked_for_clustering", &marked));
ASSERT_TRUE(marked);

node1->ClearAttr("_ngraph_marked_for_clustering");
node2->ClearAttr("_ngraph_marked_for_clustering");
node3->ClearAttr("_ngraph_marked_for_clustering");
ResetMarkForClustering(&g);

// Invalid op name should trigger an error
config::ngraph_set_disabled_ops("Add,_InvalidOp");
Expand Down
19 changes: 18 additions & 1 deletion test/graph_rewrites/mark_for_clustering_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,20 @@ TEST(MarkForClustering, SimpleTest) {
.Attr("T", DT_FLOAT)
.Finalize(&g, &node3));

Node* node4;
ASSERT_OK(NodeBuilder("node4", "Abs")
.Input(node3, 0)
.Attr("T", DT_FLOAT)
.Finalize(&g, &node4));

// Add edges from SRC to node1 and node2
// Add edge from node3 to SINK
// The graph is disconnected without these edges
Node* source = g.source_node();
Node* sink = g.sink_node();
g.AddEdge(source, Graph::kControlSlot, node1, Graph::kControlSlot);
g.AddEdge(source, Graph::kControlSlot, node2, Graph::kControlSlot);
g.AddEdge(node3, Graph::kControlSlot, sink, Graph::kControlSlot);
g.AddEdge(node4, Graph::kControlSlot, sink, Graph::kControlSlot);

const char* ng_backend_env_value = std::getenv("NGRAPH_TF_BACKEND");
string expected_backend{"CPU"};
Expand All @@ -75,9 +81,20 @@ TEST(MarkForClustering, SimpleTest) {
ASSERT_OK(MarkForClustering(&g, {}, expected_backend));

string backend;
const set<string> nodes_expected_to_be_marked{"node1", "node2", "node3",
"node4"};
for (auto node : g.op_nodes()) {
ASSERT_OK(GetNodeBackend(node, &backend));
ASSERT_EQ(backend, expected_backend);
ASSERT_EQ(nodes_expected_to_be_marked.find(node->name()) !=
nodes_expected_to_be_marked.end(),
NodeIsMarkedForClustering(node));
}

ResetMarkForClustering(&g);
for (auto node : g.op_nodes()) {
ASSERT_NOT_OK(GetNodeBackend(node, &backend));
ASSERT_FALSE(NodeIsMarkedForClustering(node));
}
}
}
Expand Down