Skip to content

Commit

Permalink
Removing visited_node hash table - fixing multinode shape mismatch is…
Browse files Browse the repository at this point in the history
…sue (#12044)

This PR fixes an issue that arises in multinode setup. MklLayoutRewritePass
maintains a hash table for visited nodes, and the hash table is part of the
pass and used for every graph being rewritten. But it looks like in multinode
setup, multiple graphs may be processed simulteneously leading to incorrect
modifications to the hash table. So removing the hash table as we do not
really need it.
  • Loading branch information
nhasabni authored and rmlarsen committed Aug 7, 2017
1 parent 32e5652 commit 4ff7190
Showing 1 changed file with 7 additions and 34 deletions.
41 changes: 7 additions & 34 deletions tensorflow/core/graph/mkl_layout_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -477,27 +477,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static ContextInfo biasaddgrad_matmul_context_;
static ContextInfo biasaddgrad_conv2dwithbias_context_;

/// Hash table to maintain nodes visited in the graph.
std::unordered_set<const Node*> visited_nodes_;

private:
// Check if we rewrote node 'n'
//
// If we rewrote the node, then the rewritten node will produce
// Mkl tensor as output. If we did not rewrite the node, then
// we need to insert dummy Mkl node on the input side.
//
// Returns true if node is rewritten, false otherwise.
inline bool IsRewrittenNode(Node* n) const {
return visited_nodes_.find(n) != visited_nodes_.end();
}

// Mark the node as rewritten
inline void MarkRewrittenNode(Node* n) { visited_nodes_.insert(n); }

// Clear all visited nodes
inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }

// Is OpDef::ArgDef a list type? It could be N * T or list(type).
// Refer to opdef.proto for details of list type.
inline bool ArgIsList(const OpDef::ArgDef& arg) const {
Expand Down Expand Up @@ -1087,15 +1067,13 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
if (IsRewrittenNode(n)) {
// If we have visited this node and rewritten it, then it will generate
// an edge that will receive Mkl tensor from a node.
// First, let's assert that this op is Mkl layer.
DataType T;
TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
// If this op has been rewritten, then its name must have been same as
// Mkl op.
CHECK_EQ(mkl_op_registry::IsMklOp(n->type_string(), T), true);

// If this is an MKL op, then it will create extra output for MKL layout.
DataType T;
if (GetNodeAttr(n->def(), "T", &T).ok() &&
mkl_op_registry::IsMklOp(n->type_string(), T)) {
// If this is an MKL op, then it will generate an edge that will receive
// Mkl tensor from a node.
// output slot number for Mkl tensor would be N+slot number of TensorFlow
// tensor, where N is total number of TensorFlow tensors.
*mkl_node = n;
Expand Down Expand Up @@ -1801,7 +1779,6 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,

(*g)->RemoveNode(succ);
(*g)->RemoveNode(pred);
MarkRewrittenNode(new_node);

return Status::OK();
}
Expand Down Expand Up @@ -1932,7 +1909,6 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,

// Delete original node and mark new node as rewritten.
(*g)->RemoveNode(orig_node);
MarkRewrittenNode(new_node);

VLOG(1) << "MklLayoutRewritePass: New node:" << new_node->DebugString();
return Status::OK();
Expand Down Expand Up @@ -2062,9 +2038,6 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {

DumpGraph("After running MklLayoutRewritePass", &**g);

// Clear marked nodes as the same graph pass may be used multiple times.
UnMarkRewrittenNodes();

return result;
}

Expand Down

0 comments on commit 4ff7190

Please sign in to comment.