[Intel MKL] Fix for convrnn unit test failure #19229

Merged
148 changes: 141 additions & 7 deletions tensorflow/core/graph/mkl_layout_pass.cc
@@ -2691,14 +2691,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
-      !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+      !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}

// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
-      !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+      !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -2865,9 +2865,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}

  // If the depth_radius of LRN is not 2, then MKL DNN takes the unoptimized
  // path. The unoptimized path is slow. Thus we don't rewrite the node
  // and use the default Eigen implementation. But for depth_radius=2, the
  // MKL DNN optimized path is taken, i.e., the Eigen node is rewritten by
  // the MKL DNN node.
static bool LrnRewrite(const Node* n) {
CHECK_NOTNULL(n);
@@ -2876,13 +2876,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CHECK_EQ(GetNodeAttr(n->def(), "depth_radius", &depth_radius).ok(), true);

  // if the depth_radius of LRN is not 2, don't rewrite the node by MKL DNN
  // and use the Eigen node instead
if (depth_radius == 2) {
return true;
}
VLOG(1) << "LrnRewrite: The model sets depth_radius as not 2 which"
<< "case is not optimized by Intel MKL, thus using Eigen op"
<< "for LRN " ;
<< "for LRN ";

return false;
}
@@ -3015,6 +3015,35 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added);

  // Helper function used by FixMklMetaDataEdges. Fixes the metadata edge
  // pointed to by 'e_metadata' corresponding to the data edge 'e_data' in
  // graph 'g'. Returns true if the fixup was done; otherwise, returns false.
bool FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
const Edge* e_data, const Edge* e_metadata);

  // Are the input Mkl metadata edges for node 'n' in graph 'g' correctly
  // connected? If not, then fix them. This is needed because a graph may have
  // some input Mkl metadata edges incorrectly set up after the node merge and
  // rewrite passes. This can happen because the GetReversePostOrder function
  // may not provide a topologically sorted order if the graph contains cycles.
  // The function returns true if at least one Mkl metadata edge for node 'n'
  // was fixed; otherwise, it returns false.
//
// Example:
//
// X = MklConv2D(_, _, _)
// Y = MklConv2DWithBias(_, _, _, _, _, _)
// Z = MklAdd(X, Y, DummyMklTensor, Y:1)
//
  // For a graph such as the one above, note that the 3rd argument of MklAdd
  // contains DummyMklTensor. Actually, it should be getting the Mkl metadata
  // from the MklConv2D op (specifically, X:2). This incorrect plumbing is
  // possible (although rare) if the Mkl NodeMerge + NodeRewrite passes visit
  // Z before X (possible if X, Y, and Z are part of a loop). This function
  // fixes the Mkl metadata edges only; it does not rewrite nodes, nor does it
  // modify the Mkl data edges (the 1st and 2nd arguments of MklAdd).
bool FixMklMetaDataEdges(std::unique_ptr<Graph>* g, Node* n);

  // Functions specific to operators to copy attributes
  // We need operator-specific functions to copy attributes because the
  // framework does not provide any generic function for it.
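
The declarations above rely on a fixed pairing between data and metadata input
slots: an Mkl node with 2*n inputs carries n Mkl data slots followed by n Mkl
metadata slots. A minimal sketch of the index math, assuming that contiguous
data-then-metadata ordering; MetaSlotForDataSlot is a hypothetical stand-in
for what GetTensorMetaDataIndex computes:

  // Illustration only: hypothetical helper mirroring the slot pairing,
  // assuming contiguous ordering (all Mkl data slots first, then all
  // Mkl metadata slots).
  int MetaSlotForDataSlot(int data_slot, int total_slots) {
    int num_data_slots = total_slots / 2;  // half data, half metadata
    return num_data_slots + data_slot;     // e.g., slot 0 of 4 pairs with slot 2
  }

For a four-input node such as the MklAdd in the example above, data slot 0
pairs with metadata slot 2, and data slot 1 pairs with metadata slot 3.
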
@@ -4241,6 +4270,92 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
return nullptr;
}

///////////////////////////////////////////////////////////////////////////////
// Post-rewrite Mkl metadata fixup pass
///////////////////////////////////////////////////////////////////////////////
bool MklLayoutRewritePass::FixMklMetaDataEdgeIfNeeded(std::unique_ptr<Graph>* g,
const Edge* e_data, const Edge* e_metadata) {
if (g == nullptr || e_data == nullptr || e_metadata == nullptr) {
return false;
}

Node* n_data = e_data->src();
int n_data_op_slot = e_data->src_output();
int n_metadata_op_slot = GetTensorMetaDataIndex(n_data_op_slot,
n_data->num_outputs());

  // If the source of the metadata edge is a constant node (producing a dummy
  // Mkl metadata tensor), then we need to fix it.
if (IsConstant(e_metadata->src())) {
Node* e_metadata_dst = e_metadata->dst();
int e_metadata_in_slot = e_metadata->dst_input();
CHECK_NOTNULL((*g)->AddEdge(n_data, n_metadata_op_slot,
e_metadata_dst, e_metadata_in_slot));

(*g)->RemoveEdge(e_metadata);
return true;
}

return false;
}

bool MklLayoutRewritePass::FixMklMetaDataEdges(std::unique_ptr<Graph>* g,
Node* n) {
bool result = false;

  // If the graph node is not an Mkl node, then return.
DataType T = DT_INVALID;
if (!GetNodeAttr(n->def(), "T", &T).ok() ||
!mkl_op_registry::IsMklOp(n->type_string(), T)) {
return result;
}

  // If it is an Mkl node, then check whether the input edges to this node
  // that carry Mkl metadata are linked up correctly with the source node.

// For Mkl nodes, we generate twice the number of input tensors (n for Mkl
// data tensors + n for Mkl metadata tensors). We need to check for correct
// connection of n metadata tensors only.
int num_data_inputs = n->num_inputs() / 2;
for (int idx = 0; idx < num_data_inputs; idx++) {
// Get the edge connecting input slot with index (idx).
const Edge* e = nullptr;
TF_CHECK_OK(n->input_edge(idx, &e));

// If e is control edge, then skip.
if (e->IsControlEdge()) {
continue;
}

    // Check whether the source node for edge 'e' is an Mkl node. If it is not
    // an Mkl node, then we don't need to do anything.
Node* e_src = e->src();
if (GetNodeAttr(e_src->def(), "T", &T).ok() &&
mkl_op_registry::IsMklOp(e_src->type_string(), T)) {
      // The source node for edge 'e' is an Mkl node. The destination node and
      // destination input slot of 'e' are node 'n' and 'idx', respectively.
CHECK_EQ(e->dst(), n);
CHECK_EQ(e->dst_input(), idx);

      // Let's get the edge that carries the Mkl metadata corresponding to the
      // Mkl data edge 'e'. For that, let's first get the input slot of 'n'
      // where the metadata edge will feed the value.
int e_meta_in_slot = GetTensorMetaDataIndex(e->dst_input(),
n->num_inputs());
const Edge* e_meta = nullptr;
TF_CHECK_OK(n->input_edge(e_meta_in_slot, &e_meta));

// Let's check if we need to fix this meta edge.
if (FixMklMetaDataEdgeIfNeeded(g, e, e_meta)) {
result = true;
}
}
}

return result;
}
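
Concretely, for the X/Y/Z example in the header comment of FixMklMetaDataEdges,
the fixup rewires only the stale metadata edge; the sketch below assumes X:2 is
the metadata output paired with X:0, as that comment states:

  //   Before fixup                        After fixup
  //   Z:in0 <- X:0    (data)              Z:in0 <- X:0
  //   Z:in1 <- Y:0    (data)              Z:in1 <- Y:0
  //   Z:in2 <- Dummy  (metadata)          Z:in2 <- X:2
  //   Z:in3 <- Y:1    (metadata)          Z:in3 <- Y:1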

///////////////////////////////////////////////////////////////////////////////
// Run function for the pass
///////////////////////////////////////////////////////////////////////////////
@@ -4307,6 +4422,25 @@ bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {

DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite)", &**g);

order.clear();
GetReversePostOrder(**g, &order); // This will give us topological sort.
for (Node* n : order) {
// If node is not an op or it cannot run on CPU device, then skip.
if (!n->IsOp() || !CanOpRunOnCPUDevice(n)) {
continue;
}
if (FixMklMetaDataEdges(g, n)) {
string node_name = n->name();
string op_name = n->type_string();

VLOG(1) << "MklLayoutRewritePass: fixed metadata edges for node "
<< node_name << " with op " << op_name;
result = true;
}
}
DumpGraph("After running MklLayoutRewritePass(NodeMerge+Rewrite+Fixup)",
&**g);

return result;
}
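
Why this extra sweep is needed: GetReversePostOrder yields a true topological
order only on acyclic graphs, so on a graph with a loop the merge and rewrite
passes can visit a consumer before its producer. An illustrative failure
sequence, following the X/Y/Z example from the FixMklMetaDataEdges comment:

  // 1. The rewrite pass visits Z first; X is not yet an Mkl op and has no
  //    metadata output, so Z's metadata input for X is wired to a dummy
  //    Mkl constant.
  // 2. The pass later visits X and rewrites it to MklConv2D, which now
  //    produces the real metadata tensor at X:2.
  // 3. Z's metadata edge is now stale; the fixup sweep detects the constant
  //    (dummy) source and rewires the edge to X:2.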

31 changes: 31 additions & 0 deletions tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3491,6 +3491,37 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}

/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test

TEST_F(MklLayoutPassTest, PostRewriteFixUpPass) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Const' "
" attr { key: 'dtype' value { type: DT_UINT8 } }"
" attr { key: 'value' value { "
" tensor { dtype: DT_UINT8 tensor_shape { dim { size: 1 } } "
" int_val: 0 } } } }"
"node { name: 'E' op: '_MklAdd'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'A', 'D', 'D']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(_MklConv2D);D(Const);E(_MklAdd);"
"M(_MklInput);N(_MklInput)|A->C;A->E:1;B->C:1;C->E;C:2->E:2;"
"D->E:3;M->C:2;N->C:3");
}

/////////////////////////////////////////////////////////////////////

static void BM_MklLayoutRewritePass(int iters, int op_nodes) {
Expand Down