[INTEL MKL] Using MKL DNN sgemm #21102

Merged

30 changes: 22 additions & 8 deletions tensorflow/core/kernels/matmul_op.cc
@@ -578,25 +578,39 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)
#if defined(INTEL_MKL)

// MKL does not support half and int32 types for matrix-multiplication, so
// MKL does not support half, bfloat16 and int32 types for matrix-multiplication, so
// register the kernel to use default Eigen based implementations for these
// types. Registration for NO-LABEL version is in mkl_matmul_op.cc
TF_CALL_float(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
// types. REGISTER_CPU defines two versions - Eigen label and NO-LABEL
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);

TF_CALL_int32(REGISTER_CPU);

// Float is supported by both MKL DNN and MKL ML
// Registration for the NO-LABEL version is in mkl_matmul_op.cc for types supported by
// MKL. However, we define the Eigen-label version here just to pass a few unit tests
TF_CALL_float(REGISTER_CPU_EIGEN);

// MKL DNN does not support complex64/complex128/double. If the user specifies
// the open-source MKL DNN only, use the default implementation for these types;
// otherwise use GEMM from the MKL ML binary.

#if defined(DO_NOT_USE_ML)
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
#else // DO_NOT_USE_ML
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
#else
TF_CALL_double(REGISTER_CPU_EIGEN);
#endif

#else // INTEL MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
TF_CALL_bfloat16(REGISTER_CPU);

TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
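
Note on the registrations above: REGISTER_CPU and REGISTER_CPU_EIGEN are kernel-registration macros defined earlier in matmul_op.cc, outside this hunk. The sketch below shows only their rough shape; the "eigen" label string and the exact builder chain are assumptions, not the file's literal contents.

// Approximate shape of the registration macros referenced above. The real
// definitions live earlier in tensorflow/core/kernels/matmul_op.cc and may
// differ in detail. REGISTER_CPU_EIGEN registers the Eigen-based MatMulOp
// only under a label, so MKL can own the unlabeled (default) registration;
// REGISTER_CPU registers both the NO-LABEL kernel and the labeled one, which
// is why the comment above says it "defines two versions".
#define REGISTER_CPU_EIGEN(T)                                          \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<T>("T")                  \
                              .Label("eigen"),                         \
                          MatMulOp<CPUDevice, T, false /* cublas */>);

#define REGISTER_CPU(T)                                                \
  REGISTER_KERNEL_BUILDER(Name("MatMul")                               \
                              .Device(DEVICE_CPU)                      \
                              .TypeConstraint<T>("T"),                 \
                          MatMulOp<CPUDevice, T, false /* cublas */>); \
  REGISTER_CPU_EIGEN(T);
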
31 changes: 28 additions & 3 deletions tensorflow/core/kernels/mkl_matmul_op.cc
@@ -23,14 +23,21 @@ limitations under the License.
// and when it is undefined at build time, this file becomes an empty
// compilation unit

#if defined(INTEL_MKL) && !defined(DO_NOT_USE_ML)

#include "mkl_cblas.h"
#if defined(INTEL_MKL)

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/fill_functor.h"

// This header file is part of MKL ML; an equivalent file is needed in MKL DNN
#ifndef DO_NOT_USE_ML
#include "mkl_cblas.h"
#else
#include "mkldnn.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -100,7 +107,6 @@ class MklMatMulOp : public OpKernel {
private:
bool transpose_a_;
bool transpose_b_;

// --------------------------------------------------------------------------
//
// @brief Matrix-Matrix Multiplication with FP32 tensors, a, b, c using CBLAS
Expand Down Expand Up @@ -150,11 +156,26 @@ class MklMatMulOp : public OpKernel {
// 1.0 and 0.0 respectively.
const float alpha = 1.0f;
const float beta = 0.0f;
#if defined(DO_NOT_USE_ML)
const char* const ftrans[] = {"N", "T", "C"};
int index_transa = transa ? 1 : 0;
int index_transb = transb ? 1 : 0;
VLOG(2) << "MKL DNN SGEMM called";
// MKL DNN only supports the Fortran API and requires column-major layout, while
// TensorFlow uses row-major, so we reverse the order of A and B
mkldnn_sgemm(ftrans[index_transb], ftrans[index_transa],
&n, &m, &k, &alpha, b, &ldb, a, &lda, &beta, c, &ldc);
#else
// MKL ML binary uses CBLAS API
cblas_sgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc);
#endif
}
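
A note on the operand swap in the mkldnn_sgemm call above: reinterpreting a row-major buffer as column-major is an implicit transpose, so asking a column-major GEMM for B * A produces C^T in column-major order, which is byte-for-byte the row-major C = A * B that TensorFlow wants. The following is a minimal, self-contained sketch of that identity (plain C++ with a naive reference GEMM, not MKL DNN code; all names are illustrative):

// Standalone sketch (not TensorFlow/MKL DNN code) of the layout trick used
// above: a column-major GEMM asked for B * A, with the row-major buffers of
// B and A reinterpreted as column-major (an implicit transpose), writes C^T
// in column-major order, which is exactly C = A * B in row-major order.
#include <cassert>
#include <vector>

// Naive column-major GEMM: C = alpha * A * B + beta * C, no transposes.
// A is m x k with leading dimension lda, B is k x n (ldb), C is m x n (ldc).
static void gemm_colmajor(int m, int n, int k, float alpha, const float* a,
                          int lda, const float* b, int ldb, float beta,
                          float* c, int ldc) {
  for (int j = 0; j < n; ++j) {
    for (int i = 0; i < m; ++i) {
      float sum = 0.0f;
      for (int p = 0; p < k; ++p) sum += a[i + p * lda] * b[p + j * ldb];
      c[i + j * ldc] = alpha * sum + beta * c[i + j * ldc];
    }
  }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C (2x2) = A * B.
  const int m = 2, k = 3, n = 2;
  std::vector<float> a = {1, 2, 3, 4, 5, 6};     // 2x3, row-major
  std::vector<float> b = {7, 8, 9, 10, 11, 12};  // 3x2, row-major
  std::vector<float> c(m * n, 0.0f);

  // Swap the operands and dimensions, exactly as MklBlasGemm does above:
  // the leading dimensions are the row-major row lengths (n for B, k for A,
  // n for C); alpha = 1, beta = 0 means C is simply overwritten.
  gemm_colmajor(n, m, k, 1.0f, b.data(), n, a.data(), k, 0.0f, c.data(), n);

  // Row-major C = A * B computed by hand.
  const float expected[] = {58, 64, 139, 154};
  for (int i = 0; i < m * n; ++i) assert(c[i] == expected[i]);
  return 0;
}

Compiling and running this sketch (e.g. g++ -std=c++11 gemm_swap_demo.cc && ./a.out, with any file name you like) should exit cleanly, confirming that the swapped column-major call fills c with the row-major product.
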

// MKLDNN only supports SGEMM
#ifndef DO_NOT_USE_ML

// Matrix-Matrix Multiplication with FP64 tensors. For detailed info about
// parameters, look at FP32 function description.
void MklBlasGemm(bool transa, bool transb, const int m, const int n,
@@ -197,6 +218,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
#endif
};

#define REGISTER_CPU(T) \
@@ -207,9 +229,12 @@ class MklMatMulOp : public OpKernel {
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);

#ifndef DO_NOT_USE_ML
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
#endif

} // namespace tensorflow
#endif // INTEL_MKL