Merge pull request #60307 from Intel-tensorflow:bhavanis/onednn-3.0-conv-transpose

PiperOrigin-RevId: 524548684
tensorflower-gardener committed Apr 15, 2023
2 parents 3a1275b + 1ad4ffa commit ce37699
Showing 42 changed files with 280 additions and 110 deletions.
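
At a high level the commit does two things: kernels and graph rewrites that still depend on oneDNN v2-only APIs are compiled out when ENABLE_ONEDNN_V3 is defined, and the convolution-gradient (conv-transpose) primitives are ported to the oneDNN v3 API. A minimal sketch of the file-level guard idiom that recurs in the hunks below (illustrative only, not a file from this commit):

// Kernels not yet ported to oneDNN v3 build only against oneDNN v2.
#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3)

// ... oneDNN v2-only kernel implementation ...

#endif  // INTEL_MKL && !ENABLE_ONEDNN_V3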
8 changes: 8 additions & 0 deletions tensorflow/core/common_runtime/mkl_layout_pass.cc
@@ -383,6 +383,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

const bool native_fmt = NativeFormatEnabled();
// NOTE: names are alphabetically sorted.
#ifndef ENABLE_ONEDNN_V3
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
@@ -418,6 +419,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.concatv2,
mkl_op_registry::GetMklOpName(csinfo_.concatv2),
CopyAttrsAll, ConcatV2Rewrite, GetRewriteCause()});
#endif // !ENABLE_ONEDNN_V3
rinfo_.push_back(
{csinfo_.conjugate_transpose,
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
@@ -462,6 +464,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.depthwise_conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_filter),
CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#ifndef ENABLE_ONEDNN_V3
rinfo_.push_back(
{csinfo_.dequantize, mkl_op_registry::GetMklOpName(csinfo_.dequantize),
CopyAttrsAll, DequantizeRewrite, kRewriteForOpNameChange});
@@ -496,6 +499,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
: csinfo_.mkl_fused_batch_norm_ex,
CopyAttrsAll, FusedBatchNormExRewrite,
GetRewriteCause()});
#endif // !ENABLE_ONEDNN_V3
rinfo_.push_back({csinfo_.fused_conv2d,
native_fmt ? csinfo_.mkl_native_fused_conv2d
: csinfo_.mkl_fused_conv2d,
@@ -509,6 +513,7 @@
: csinfo_.mkl_fused_depthwise_conv2d,
CopyAttrsAllCheckConstFilter, FusedDepthwiseConv2DRewrite,
GetRewriteCause()});
#ifndef ENABLE_ONEDNN_V3
rinfo_.push_back({csinfo_.fused_matmul,
native_fmt ? csinfo_.mkl_native_fused_matmul
: csinfo_.mkl_fused_matmul,
@@ -549,6 +554,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.mul, mkl_op_registry::GetMklOpName(csinfo_.mul),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
GetRewriteCause()});
#endif // !ENABLE_ONEDNN_V3
rinfo_.push_back({csinfo_.pad_with_conv2d,
native_fmt ? csinfo_.mkl_native_pad_with_conv2d
: csinfo_.mkl_pad_with_conv2d,
@@ -559,6 +565,7 @@
: csinfo_.mkl_pad_with_fused_conv2d,
CopyAttrsAllCheckConstFilter, AlwaysRewrite,
GetRewriteCause()});
#ifndef ENABLE_ONEDNN_V3
rinfo_.push_back({csinfo_.quantized_avg_pool,
mkl_op_registry::GetMklOpName(csinfo_.quantized_avg_pool),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
@@ -711,6 +718,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.sub, mkl_op_registry::GetMklOpName(csinfo_.sub),
CopyAttrsAll, RewriteIfAtleastOneMklInput,
GetRewriteCause()});
#endif // !ENABLE_ONEDNN_V3
rinfo_.push_back({csinfo_.transpose,
mkl_op_registry::GetMklOpName(csinfo_.transpose),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
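
For context on the hunks above: each rinfo_ entry registers a graph rewrite mapping a TensorFlow op to its MKL counterpart, together with an attribute-copy callback and a rewrite predicate. Registrations for ops not yet ported to oneDNN v3 are now wrapped in guards, while already-ported ops stay unconditional; a condensed sketch built from the registrations shown above:

#ifndef ENABLE_ONEDNN_V3
    // Compiled out under oneDNN v3 until these ops are ported.
    rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                      CopyAttrsAll, AlwaysRewrite, GetRewriteCause()});
#endif  // !ENABLE_ONEDNN_V3
    // Already-ported ops such as transpose remain registered in both builds.
    rinfo_.push_back({csinfo_.transpose,
                      mkl_op_registry::GetMklOpName(csinfo_.transpose),
                      CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});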
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_avgpooling_op.cc
@@ -13,7 +13,7 @@
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3)
#define EIGEN_USE_THREADS

#include "dnnl.hpp"
@@ -438,4 +438,4 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedAvgPool")

} // namespace tensorflow

#endif // INTEL_MKL
#endif // INTEL_MKL && !ENABLE_ONEDNN_V3
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_batch_matmul_helper.h
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MKL_MKL_BATCH_MATMUL_HELPER_H_
#define TENSORFLOW_CORE_KERNELS_MKL_MKL_BATCH_MATMUL_HELPER_H_
#ifdef INTEL_MKL
#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3)

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
@@ -106,5 +106,5 @@ struct MklBatchMatMulHelper {

} // namespace tensorflow

#endif // INTEL_MKL
#endif // INTEL_MKL && !ENABLE_ONEDNN_V3
#endif // TENSORFLOW_CORE_KERNELS_MKL_MKL_BATCH_MATMUL_HELPER_H_
6 changes: 2 additions & 4 deletions tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
@@ -23,7 +23,7 @@ limitations under the License.

#define EIGEN_USE_THREADS

#if defined(INTEL_MKL)
#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3)

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/register_types.h"
@@ -327,14 +327,12 @@ class FusedBatchMatMulMkl
.TypeConstraint<TYPE>("T"), \
FusedBatchMatMulMkl<CPUDevice, TYPE, TYPE, TYPE, true>)

#ifdef INTEL_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_float(REGISTER_FUSED_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_bfloat16(REGISTER_FUSED_BATCH_MATMUL_MKL);
#endif // INTEL_MKL

} // end namespace tensorflow
#endif
#endif // INTEL_MKL && !ENABLE_ONEDNN_V3
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/mkl/mkl_concat_op.cc
@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
#if defined(INTEL_MKL) && !defined(ENABLE_ONEDNN_V3)
#define EIGEN_USE_THREADS

#include <limits>
@@ -985,4 +985,4 @@ REGISTER_QUANTIZED_CONCATV2(qint8);
#undef REGISTER_CONCAT_MKL
} // namespace tensorflow

#endif // INTEL_MKL
#endif // INTEL_MKL && !ENABLE_ONEDNN_V3
37 changes: 36 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_conv_grad_filter_ops.cc
@@ -36,7 +36,9 @@ namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

#ifndef ENABLE_ONEDNN_V3
using ConvBwdFilterDesc = dnnl::convolution_backward_weights::desc;
#endif // !ENABLE_ONEDNN_V3
using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc;

struct MklConvBwdFilterParams {
@@ -157,11 +159,15 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {

// Primitive descriptor and descriptor for convolution backward filter.
std::shared_ptr<ConvBwdFilterPd> bwd_filter_pd;
#ifndef ENABLE_ONEDNN_V3
std::shared_ptr<ConvBwdFilterDesc> bwd_filter_desc;
#endif // !ENABLE_ONEDNN_V3

// Primitive descriptor and descriptor for convolution forward.
std::shared_ptr<ConvFwdPd> fwd_pd;
#ifndef ENABLE_ONEDNN_V3
std::shared_ptr<ConvFwdDesc> fwd_desc;
#endif // !ENABLE_ONEDNN_V3

// Convolution backward filter primitive.
std::shared_ptr<dnnl::primitive> conv_bwd_filter;
@@ -182,13 +188,18 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
diff_filter_mem(nullptr),
diff_bias_mem(nullptr),
diff_dst_mem(nullptr),
#ifndef ENABLE_ONEDNN_V3
bwd_filter_desc(nullptr),
#endif // !ENABLE_ONEDNN_V3
fwd_pd(nullptr),
#ifndef ENABLE_ONEDNN_V3
fwd_desc(nullptr),
#endif // !ENABLE_ONEDNN_V3
src_md(nullptr),
diff_filter_md(nullptr),
diff_bias_md(nullptr),
diff_dst_md(nullptr) {}
diff_dst_md(nullptr) {
}
};

void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
Expand Down Expand Up @@ -217,6 +228,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
memory::format_tag::x));

#ifndef ENABLE_ONEDNN_V3
// Create descriptor and primitive descriptor for convolution forward.
context_.fwd_desc.reset(new ConvFwdDesc(
prop_kind::forward, dnnl::algorithm::convolution_direct,
@@ -242,6 +254,29 @@
}
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
#else
context_.fwd_pd.reset(new ConvFwdPd(
cpu_engine_, prop_kind::forward, dnnl::algorithm::convolution_direct,
*context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right));

if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
cpu_engine_, dnnl::algorithm::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_bias_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
convBwdFilterDims.padding_right, *context_.fwd_pd));
} else {
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
cpu_engine_, dnnl::algorithm::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
*context_.fwd_pd));
}
#endif // !ENABLE_ONEDNN_V3

auto bwd_filter_pd = context_.bwd_filter_pd.get();

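
The hunk above is an instance of the central oneDNN 3.x migration pattern: the standalone operation descriptors (dnnl::convolution_backward_weights::desc and friends) were removed, and their parameters now go directly into the primitive_desc constructor, which takes the engine as its first argument. A self-contained sketch of the same pattern for convolution forward, with made-up shapes (not code from this commit; constructor overloads as in the public dnnl.hpp):

#include "dnnl.hpp"

void PrimitiveDescSketch() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::memory::dims strides{1, 1}, pad_l{0, 0}, pad_r{0, 0};
  dnnl::memory::desc src_md({2, 8, 14, 14}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::any);
  dnnl::memory::desc wei_md({16, 8, 3, 3}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::any);
  dnnl::memory::desc dst_md({2, 16, 12, 12}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::any);
#ifndef ENABLE_ONEDNN_V3
  // oneDNN v2.x: build an op descriptor first, then the primitive descriptor.
  dnnl::convolution_forward::desc fwd_desc(
      dnnl::prop_kind::forward, dnnl::algorithm::convolution_direct, src_md,
      wei_md, dst_md, strides, pad_l, pad_r);
  dnnl::convolution_forward::primitive_desc fwd_pd(fwd_desc, eng);
#else
  // oneDNN v3.x: no op descriptor; the engine and all parameters go
  // straight into the primitive descriptor.
  dnnl::convolution_forward::primitive_desc fwd_pd(
      eng, dnnl::prop_kind::forward, dnnl::algorithm::convolution_direct,
      src_md, wei_md, dst_md, strides, pad_l, pad_r);
#endif
}

Backward primitive descriptors follow the same shape, with the forward primitive_desc passed as a hint argument at the end, as the hunk above shows.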
41 changes: 36 additions & 5 deletions tensorflow/core/kernels/mkl/mkl_conv_grad_input_ops.cc
@@ -39,8 +39,15 @@ using dnnl::prop_kind;
using dnnl::stream;

namespace tensorflow {
#ifndef ENABLE_ONEDNN_V3
#define SET_MKL_LAYOUT(md) SetMklLayout(&md)
#else
#define SET_MKL_LAYOUT(md) SetMklLayout(md)
#endif // !ENABLE_ONEDNN_V3

#ifndef ENABLE_ONEDNN_V3
using ConvBwdDataDesc = dnnl::convolution_backward_data::desc;
#endif // !ENABLE_ONEDNN_V3
using ConvBwdDataPd = dnnl::convolution_backward_data::primitive_desc;

// Utility classes for enabling primitive reuse for conv bwd input.
@@ -96,7 +103,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
#ifdef DNNL_AARCH64_USE_ACL
mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
#if !defined(ENABLE_ONEDNN_OPENMP) && !defined(ENABLE_ONEDNN_V3)
// TODO(intel-tf): Create a common function and avoid the duplicate code
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)), *bwd_input_stream);
@@ -111,7 +118,7 @@
static_cast<T*>(const_cast<T*>(filter_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_dst_data)));
#endif // !ENABLE_ONEDNN_OPENMP
#endif // !ENABLE_ONEDNN_OPENMP && !ENABLE_ONEDNN_V3
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
context_.bwd_input_primitives_args);

@@ -136,11 +143,15 @@

// Conv backward input primitive descriptor and descriptor.
std::shared_ptr<ConvBwdDataPd> bwd_input_pd;
#ifndef ENABLE_ONEDNN_V3
std::shared_ptr<ConvBwdDataDesc> bwd_input_desc;
#endif // !ENABLE_ONEDNN_V3

// Primitive descriptor and descriptor for conv fwd
std::shared_ptr<ConvFwdPd> fwd_pd;
#ifndef ENABLE_ONEDNN_V3
std::shared_ptr<ConvFwdDesc> fwd_desc;
#endif // !ENABLE_ONEDNN_V3

// Conv bwd input primitive.
std::shared_ptr<dnnl::primitive> conv_bwd_input;
@@ -159,13 +170,18 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
filter_mem(nullptr),
diff_dst_mem(nullptr),
bwd_input_pd(nullptr),
#ifndef ENABLE_ONEDNN_V3
bwd_input_desc(nullptr),
#endif // !ENABLE_ONEDNN_V3
fwd_pd(nullptr),
#ifndef ENABLE_ONEDNN_V3
fwd_desc(nullptr),
#endif // !ENABLE_ONEDNN_V3
conv_bwd_input(nullptr),
diff_src_md(nullptr),
filter_md(nullptr),
diff_dst_md(nullptr) {}
diff_dst_md(nullptr) {
}
};

void Setup(const MklConvBwdInputParams& convBwdInputDims) {
Expand All @@ -188,6 +204,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
memory::format_tag::any));

// Create descriptors for both conv fwd and conv bwd input.
#ifndef ENABLE_ONEDNN_V3
context_.bwd_input_desc.reset(new ConvBwdDataDesc(
dnnl::algorithm::convolution_direct, *context_.diff_src_md,
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
Expand All @@ -204,6 +221,19 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
context_.bwd_input_pd.reset(new ConvBwdDataPd(
*context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd));
#else
context_.fwd_pd.reset(new ConvFwdPd(
cpu_engine_, prop_kind::forward, dnnl::algorithm::convolution_direct,
*context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md,
convBwdInputDims.strides, convBwdInputDims.dilations,
convBwdInputDims.padding_left, convBwdInputDims.padding_right));

context_.bwd_input_pd.reset(new ConvBwdDataPd(
cpu_engine_, dnnl::algorithm::convolution_direct, *context_.diff_src_md,
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
convBwdInputDims.dilations, convBwdInputDims.padding_left,
convBwdInputDims.padding_right, *context_.fwd_pd));
#endif // !ENABLE_ONEDNN_V3

// Create memory using dummy data.
context_.diff_src_mem.reset(new memory(
@@ -449,7 +479,7 @@ class MklConvCustomBackpropInputOp
// Allocate output tensor.
MklDnnShape diff_src_mkl_shape;
diff_src_mkl_shape.SetMklTensor(true);
diff_src_mkl_shape.SetMklLayout(&diff_src_pd);
diff_src_mkl_shape.SET_MKL_LAYOUT(diff_src_pd);
diff_src_mkl_shape.SetElemType(MklDnnType<T>());
diff_src_mkl_shape.SetTfLayout(bwd_diff_src_dims.size(),
bwd_diff_src_dims, bwd_diff_src_format);
@@ -578,7 +608,7 @@ class MklConvCustomBackpropInputOp
// Allocate shape of MKL tensor.
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SET_MKL_LAYOUT(dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);
@@ -634,6 +664,7 @@ TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS);

#undef REGISTER_MKL_CPU_KERNELS
#undef SET_MKL_LAYOUT

} // namespace tensorflow
#endif // INTEL_MKL
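
One more detail from this file: the SET_MKL_LAYOUT macro defined near its top papers over a signature change in MklDnnShape::SetMklLayout between the two integrations. The diff only shows the call sites, so the signatures below are an assumption for illustration (the real declarations live in mkl_util.h):

// Assumed signatures, for illustration only.
#ifndef ENABLE_ONEDNN_V3
void SetMklLayout(dnnl::memory::desc* md);        // v2 integration: pointer
#else
void SetMklLayout(const dnnl::memory::desc& md);  // v3 integration: the desc itself (assumption)
#endif

// Hence SET_MKL_LAYOUT(md) expands to SetMklLayout(&md) under v2 and to
// SetMklLayout(md) under v3, keeping call sites such as
// diff_src_mkl_shape.SET_MKL_LAYOUT(diff_src_pd) identical in both builds.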
(Remaining changed files not shown.)
