diff --git a/ngraph_bridge/enable_variable_ops/ngraph_replace_op_utilities.cc b/ngraph_bridge/enable_variable_ops/ngraph_replace_op_utilities.cc index e45104875..4f535ca5b 100644 --- a/ngraph_bridge/enable_variable_ops/ngraph_replace_op_utilities.cc +++ b/ngraph_bridge/enable_variable_ops/ngraph_replace_op_utilities.cc @@ -194,11 +194,15 @@ Status ReplaceVariable(Graph* graph, Node* node, Node** replacement, // Though edges will be removed when we remove the node // we specifically remove the edges to be sure Status ReplaceInputControlEdges(Graph* graph, Node* node, Node* replacement) { + std::vector edges_to_remove; for (auto edge : node->in_edges()) { NGRAPH_VLOG(4) << "Replacing: " << edge->DebugString(); if (!edge->IsControlEdge()) continue; graph->AddEdge(edge->src(), edge->src_output(), replacement, edge->dst_input()); + edges_to_remove.push_back(edge); + } + for (auto edge : edges_to_remove) { graph->RemoveEdge(edge); } return Status::OK(); @@ -208,6 +212,7 @@ Status ReplaceInputControlEdges(Graph* graph, Node* node, Node* replacement) { // we specifically remove the edges to be sure Status ReplaceOutputEdges(Graph* graph, Node* node, Node* replacement) { std::vector edges; + std::vector edges_to_remove; for (auto edge : node->out_edges()) { edges.push_back(edge); } @@ -216,9 +221,11 @@ Status ReplaceOutputEdges(Graph* graph, Node* node, Node* replacement) { NGRAPH_VLOG(4) << "Replacing: " << edge->DebugString(); graph->AddEdge(replacement, edge->src_output(), edge->dst(), edge->dst_input()); + edges_to_remove.push_back(edge); + } + for (auto edge : edges_to_remove) { graph->RemoveEdge(edge); } - return Status::OK(); } diff --git a/ngraph_bridge/ngraph_capture_variables.cc b/ngraph_bridge/ngraph_capture_variables.cc index a751a8479..915fa869a 100644 --- a/ngraph_bridge/ngraph_capture_variables.cc +++ b/ngraph_bridge/ngraph_capture_variables.cc @@ -98,20 +98,18 @@ Status CaptureVariables(Graph* graph, const std::set skip_these_nodes) { edge->dst_input()); edges_to_remove.push_back(edge); } - // Though edges will be removed when we remove the node - // we specifically remove the edges to be sure - for (auto edge : edges_to_remove) { - graph->RemoveEdge(edge); - } for (auto edge : node->out_edges()) { - edges.push_back(edge); - } - - for (auto edge : edges) { NGRAPH_VLOG(4) << "Replacing: " << edge->DebugString(); graph->AddEdge(replacement, edge->src_output(), edge->dst(), edge->dst_input()); + edges_to_remove.push_back(edge); + } + + // Though edges will be removed when we remove the node + // we specifically remove the edges to be sure + for (auto edge : edges_to_remove) { + NGRAPH_VLOG(4) << "Removing: " << edge->DebugString(); graph->RemoveEdge(edge); } diff --git a/ngraph_bridge/ngraph_rewrite_for_tracking.cc b/ngraph_bridge/ngraph_rewrite_for_tracking.cc index add5871e9..c39013eee 100644 --- a/ngraph_bridge/ngraph_rewrite_for_tracking.cc +++ b/ngraph_bridge/ngraph_rewrite_for_tracking.cc @@ -90,23 +90,26 @@ Status RewriteForTracking(Graph* graph, int graph_id) { NGRAPH_VLOG(4) << "Replacing Node " << node->DebugString() << " with " << replacement->DebugString(); - // Though edges will be removed when we remove the node - // we specifically remove the edges to be sure + std::vector edges_to_remove; + for (auto edge : node->in_edges()) { - NGRAPH_VLOG(4) << "Replacing: " << edge->DebugString(); + NGRAPH_VLOG(4) << "Replacing: In Edge " << edge->DebugString(); graph->AddEdge(edge->src(), edge->src_output(), replacement, edge->dst_input()); - graph->RemoveEdge(edge); + edges_to_remove.push_back(edge); } - std::vector edges; for (auto edge : node->out_edges()) { - edges.push_back(edge); - } - for (auto edge : edges) { - NGRAPH_VLOG(4) << "Replacing: " << edge->DebugString(); + NGRAPH_VLOG(4) << "Replacing: OutEdge " << edge->DebugString(); graph->AddEdge(replacement, edge->src_output(), edge->dst(), edge->dst_input()); + edges_to_remove.push_back(edge); + } + + // Though edges will be removed when we remove the node + // we specifically remove the edges to be sure + for (auto edge : edges_to_remove) { + NGRAPH_VLOG(4) << "Removing: Edges " << edge->DebugString(); graph->RemoveEdge(edge); }