[INTEL MKL] Fix for regressions in //tensorflow/python/ops/parallel_for #21007

Merged
76 changes: 36 additions & 40 deletions tensorflow/core/graph/mkl_layout_pass.cc
@@ -2495,13 +2495,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsLRN, LrnRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
mkl_op_registry::GetMklOpName(csinfo_.lrn_grad),
CopyAttrsLRN, LrnRewrite});
CopyAttrsLRN, LrnGradRewrite});
rinfo_.push_back({csinfo_.max_pool,
mkl_op_registry::GetMklOpName(csinfo_.max_pool),
CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
CopyAttrsPooling, MaxpoolGradRewrite});

rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
@@ -2887,6 +2887,37 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
return false;
}

static bool LrnGradRewrite(const Node* n) {
CHECK_NOTNULL(n);
bool do_rewrite = false;

for (const Edge* e : n->in_edges()) {
// Rewrite only if there is a corresponding LRN, i.e., a workspace is available
if (e->dst()->type_string() == csinfo_.lrn_grad && e->dst_input() == 2 &&
e->src()->type_string() == mkl_op_registry::GetMklOpName(csinfo_.lrn) &&
e->src_output() == 0) {
do_rewrite = true;
break;
}
}
return do_rewrite;
}

static bool MaxpoolGradRewrite(const Node* n) {
CHECK_NOTNULL(n);
bool do_rewrite = false;
for (const Edge* e : n->in_edges()) {
// Rewrite only if there is a corresponding MaxPool, i.e., a workspace is available
if (e->dst()->type_string() == csinfo_.max_pool_grad && e->dst_input() == 1 &&
e->src()->type_string() == mkl_op_registry::GetMklOpName(csinfo_.max_pool) &&
e->src_output() == 0) {
do_rewrite = true;
break;
}
}
return do_rewrite;
}

static bool AddNRewrite(const Node* n) {
CHECK_NOTNULL(n);

@@ -3421,44 +3452,9 @@ Status MklLayoutRewritePass::SetUpInputs(
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
// We use a tensor of shape {1} and value 0 to represent
// dummy float tensor. We need this as a dummy workspace tensor.
// Workspace tensor has type uint8.
const DataType dt = DataTypeToEnum<uint8>::v();
TensorProto proto;
proto.set_dtype(dt);
float zero[1] = {0};
proto.set_tensor_content(string(reinterpret_cast<char*>(&zero), 4));
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on the
// same device as the original
// node.
.Finalize(&**g, out));

// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
// the dummy Mkl node. This is needed because control-flow ops such as Enter,
// Merge, etc, require frame_name of the dummy Mkl node to be same as the
// rewritten node. Adding control edge between 1st input of the original node
// and the dummy Mkl node ensures that the dummy node is in the same frame
// as the original node. Choosing 1st input is not necessary - any input of
// the original node is fine because all the inputs of a node are always in
// the same frame.
if (orig_node->num_inputs() > 0) {
Node* orig_input0 = nullptr;
TF_CHECK_OK(
orig_node->input_node(0, const_cast<const Node**>(&orig_input0)));
// Allow duplicate while adding control edge as it would fail (return
// NULL) if we try to add duplicate edge.
CHECK_NOTNULL((*g)->AddControlEdge(orig_input0, *out, true));
}

(*out)->set_assigned_device_name(orig_node->assigned_device_name());
// We use a uint8 tensor of shape {8} with content {0,0,0,0,0,0,0,0} to
// represent the dummy workspace tensor.
GetDummyMklTensorNode(g, out, orig_node);
}

void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
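For context on the rewritten GetDummyWorkspaceTensorNode above: the work is now delegated to GetDummyMklTensorNode, whose implementation is not part of this diff. The snippet below is only an illustrative sketch, based on the comment in the new body and on the code that was removed, of the kind of Const node that helper is assumed to build (a uint8 tensor of shape {8} holding eight zero bytes, placed on the original node's device); it is not a copy of the actual helper.

// Sketch only: approximates what GetDummyMklTensorNode is assumed to build
// for the dummy workspace tensor (uint8, shape {8}, all zeros).
const DataType dt = DataTypeToEnum<uint8>::v();
TensorProto proto;
proto.set_dtype(dt);
uint8 zeros[8] = {0, 0, 0, 0, 0, 0, 0, 0};
proto.set_tensor_content(string(reinterpret_cast<char*>(zeros), 8));
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
                .Attr("value", proto)
                .Attr("dtype", dt)
                .Device(orig_node->def().device())  // same device as orig_node
                .Finalize(&**g, out));
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
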
30 changes: 9 additions & 21 deletions tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3015,12 +3015,8 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
"A(Input);B(Input);C(Input);D(LRNGrad);"
"E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}

/* Test LRN->LRNGrad negative case, where single LRN feeds
@@ -3058,15 +3054,11 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
" input: ['E', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Zeta)|A->B;"
"A:control->DMT/_0:control;B->E:2;"
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;"
"C:control->DMT/_1:control;C:control->DMT/_2:control;"
"C:control->DMT/_3:control;C:control->DMT/_4:control;"
"C:control->DMT/_5:control;C:control->DMT/_6:control;"
"D->E:1;D->F:2;DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;DMT/_3->F:3;"
"DMT/_4->F:7;DMT/_5->F:4;DMT/_6->F:6;E->G;F->G:1");
"DMT/_2(Const);E(_MklLRNGrad);F(LRNGrad);G(Zeta)|A->B;"
"A:control->DMT/_0:control;B->E:2;B->F:1;B:1->E:3;B:2->E:6;"
"B:3->E:7;C->E;C->F;C:control->DMT/_1:control;"
"C:control->DMT/_2:control;D->E:1;D->F:2;DMT/_0->B:1;"
"DMT/_1->E:4;DMT/_2->E:5;E->G;F->G:1");
}

/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
@@ -3137,12 +3129,8 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
"node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Zeta)|"
"A->D;A->E;A:control->DMT/_0:control;A:control->DMT/_1:control;"
"A:control->DMT/_2:control;A:control->DMT/_3:control;"
"A:control->DMT/_4:control;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;"
"DMT/_1->D:7;DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
"A(Input);B(Input);C(Input);D(MaxPoolGrad);"
"E(Zeta)|A->D;A->E;B->D:1;C->D:2;D->E:1");
}

// Test MaxPool handling for batch-wise pooling (NCHW)
6 changes: 5 additions & 1 deletion tensorflow/core/kernels/mkl_reshape_op.cc
@@ -152,8 +152,12 @@ class MklReshapeOp : public OpKernel {
// If Tensorflow's data format and the underlying format maintained by
// MKLDNN are equivalent (both are NHWC or both are NCHW), then we can
// safely return true.
// TODO: In the future, do not force-skip the reorder for every blocked
// format. Instead, use blocking_desc_is_equal() from
// mkl-dnn/blob/master/src/common/type_helpers.hpp to compare all the
// stride arrays.
auto input_mkl_md = mkl_shape_input.GetMklLayout();
if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format) {
if (mkl_shape_input.GetTfDataFormat() == input_mkl_md.data.format &&
mkl_shape_input.GetTfDataFormat() != memory::format::blocked) {
ret = true;
}

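A note on the TODO in this hunk: a rough sketch of the stricter check it calls for is shown below. The mkl-dnn field paths (data.layout_desc.blocking.strides) and the idea of comparing stride arrays dimension by dimension are assumptions drawn from the comment; this is not the actual blocking_desc_is_equal() from type_helpers.hpp, only an illustration of the intent.

// Hypothetical helper sketching the stride comparison the TODO suggests.
// Field names inside mkldnn_memory_desc_t are assumed and may differ
// between mkl-dnn versions.
static bool BlockingStridesEqual(const memory::desc& lhs,
                                 const memory::desc& rhs) {
  if (lhs.data.ndims != rhs.data.ndims) return false;
  const auto& lb = lhs.data.layout_desc.blocking;
  const auto& rb = rhs.data.layout_desc.blocking;
  for (int d = 0; d < lhs.data.ndims; ++d) {
    // Compare the outer (between-block) stride for each dimension.
    if (lb.strides[0][d] != rb.strides[0][d]) return false;
  }
  return true;
}
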