Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[INTEL MKL] Enabled Conv2D fprop for MKL-DNN v1.0. #30549

Merged
10 changes: 8 additions & 2 deletions tensorflow/core/graph/mkl_layout_pass.cc
Expand Up @@ -353,9 +353,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mul = "Mul";
csinfo_.squared_difference = "SquaredDifference";
csinfo_.sub = "Sub";
// End - element-wise ops. See note above.
// End - element-wise ops. See note above.
penpornk marked this conversation as resolved.
Show resolved Hide resolved

// NOTE: names are alphabetically sorted.
// NOTE: names are alphabetically sorted.
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
Expand Down Expand Up @@ -389,10 +390,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.conjugate_transpose,
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
#endif // !ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
Expand Down Expand Up @@ -641,6 +644,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
Expand All @@ -653,6 +657,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
Expand Down Expand Up @@ -753,6 +758,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// CheckForMklOp
FuseConv3D,
CopyAttrsConv});
#endif // !ENABLE_MKLDNN_V1
}

// Standard interface to run pass
Expand Down