[Intel MKL] Fixing bfloat16 integration in MatMul and BatchMatMul for… #37954

Merged
86 changes: 54 additions & 32 deletions tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -29,7 +29,6 @@ limitations under the License.
#include <vector>

#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -44,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/mkl_util.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

namespace tensorflow {

@@ -148,35 +148,56 @@ class BatchMatMulMkl : public OpKernel {
std::vector<MKL_INT> ldb_array(batch_size, adj_y_ ? K : N);
std::vector<MKL_INT> ldc_array(batch_size, N);
std::vector<MKL_INT> group_size(1, batch_size);
std::vector<const Scalar*> a_array;
std::vector<const Scalar*> b_array;
std::vector<Scalar*> c_array;
a_array.reserve(batch_size);
b_array.reserve(batch_size);
c_array.reserve(batch_size);

if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(i, 0, 0));
b_array.push_back(&rhs_reshaped(i, 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}

if (std::is_same<Scalar, bfloat16>::value) {
// DNNL bfloat16 API requires a, b, and c as pointers to tensors
// represented as a flat byte array.
const Scalar* a = nullptr;
const Scalar* b = nullptr;
OP_REQUIRES(ctx, !bcast.IsBroadcastingRequired(),
errors::Unimplemented("Broadcasting is not supported for "
"BFloat16 _MklBatchMatMul yet."));
a = &lhs_reshaped(0, 0, 0);
b = &rhs_reshaped(0, 0, 0);
Scalar* c = &out_reshaped(0, 0, 0);
// TODO(nhasabni): Use appropriate cast instead of passing addresses of
// a,b and c.
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
k_array, &a, lda_array, &b, ldb_array, &c, ldc_array, 1,
group_size);
} else {
// Broadcasting is needed, so get the mapping from flattened output batch
// indices to x's and y's flattened batch indices.
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();

for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0));
b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
std::vector<const Scalar*> a_array;
std::vector<const Scalar*> b_array;
std::vector<Scalar*> c_array;
a_array.reserve(batch_size);
b_array.reserve(batch_size);
c_array.reserve(batch_size);

if (!bcast.IsBroadcastingRequired()) {
for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(i, 0, 0));
b_array.push_back(&rhs_reshaped(i, 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
} else {
// Broadcasting is needed, so get the mapping from flattened output batch
// indices to x's and y's flattened batch indices.
const std::vector<int64>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64>& b_batch_indices = bcast.y_batch_indices();

for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(a_batch_indices[i], 0, 0));
b_array.push_back(&rhs_reshaped(b_batch_indices[i], 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
}
}

MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array, k_array,
&a_array[0], lda_array, &b_array[0], ldb_array,
&c_array[0], ldc_array, 1, group_size);
// MKL CBLAS API requires a, b, and c as array of pointers, where each
// pointer is to 2D matrix.
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, m_array, n_array,
k_array, &a_array[0], lda_array, &b_array[0], ldb_array,
&c_array[0], ldc_array, 1, group_size);

Member:
a_array, b_array, and c_array are only used here. Please put lines 151-174 in else too.

Contributor Author:
Done. Thanks.
}
}

private:
@@ -269,10 +290,11 @@ class BatchMatMulMkl : public OpKernel {
std::vector<bool> TransB_Array(group_size[0], TransB);
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
// TODO(nhasabni): Remove *A when we pass a, b, and c correctly.
// MKLDNN API does not require lda, ldb, and ldc.
dnnl_gemm_batch<bfloat16>(TransA_Array, TransB_Array, M_Array, N_Array,
K_Array, alpha_Array, A_Array, lda_Array, B_Array,
ldb_Array, beta_Array, C_Array, ldc_Array,
group_count, group_size);
K_Array, alpha_Array, *A_Array, *B_Array,
beta_Array, *C_Array, group_count, group_size);
}
#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
};
@@ -302,10 +324,10 @@ TF_CALL_double(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL_V2);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL_V2);

#if defined(ENABLE_INTEL_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
#if defined(ENABLE_MKLDNN_V1) && defined(ENABLE_INTEL_MKL_BFLOAT16)
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_bfloat16(REGISTER_BATCH_MATMUL_MKL_V2);
#endif // ENABLE_INTEL_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
#endif // ENABLE_MKLDNN_V1 && ENABLE_INTEL_MKL_BFLOAT16
#endif // ENABLE_MKL

} // end namespace tensorflow
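
For reference, the change above separates two pointer conventions: the MKL CBLAS *gemm_batch path keeps per-matrix pointer arrays (a_array, b_array, c_array), while the DNNL bfloat16 path now passes one flat pointer per operand and lets strides locate each batch member. A minimal standalone C++ sketch (shapes, names, and the program itself are illustrative assumptions, not TensorFlow code) showing that both conventions address the same contiguous batch buffer:

// Two views over the same contiguous buffer of `batch` row-major m x k
// matrices: a single flat pointer (oneDNN-style, batch i starts at offset
// i * m * k) and an array of per-batch pointers (CBLAS *gemm_batch-style).
#include <cstdio>
#include <vector>

int main() {
  const int batch = 2, m = 2, k = 3;
  std::vector<float> a(batch * m * k);
  for (int i = 0; i < static_cast<int>(a.size()); ++i) a[i] = static_cast<float>(i);

  // Convention 1: one base pointer; strides locate each batch member.
  const float* a_flat = a.data();

  // Convention 2: explicit per-batch pointers, as in a_array/b_array/c_array.
  std::vector<const float*> a_array;
  for (int i = 0; i < batch; ++i) a_array.push_back(a.data() + i * m * k);

  // Both views reach the same element (batch 1, row 1, col 2): prints "11 11".
  std::printf("%.0f %.0f\n", a_flat[1 * m * k + 1 * k + 2], a_array[1][1 * k + 2]);
  return 0;
}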
16 changes: 8 additions & 8 deletions tensorflow/core/kernels/mkl_matmul_op.cc
@@ -181,17 +181,18 @@ class MklMatMulOp : public OpKernel {
const int ldc) {
const float alpha = 1.0f;
const float beta = 0.0f;
const char* const ftrans[] = {"N", "T", "C"};
const int index_transa = transa ? 1 : 0;
const int index_transb = transb ? 1 : 0;

#ifdef ENABLE_MKLDNN_V1
dnnl_gemm<bfloat16>(transa ? CblasTrans : CblasNoTrans,
transb ? CblasTrans : CblasNoTrans, m, n, k, alpha, a,
lda, b, ldb, beta, c, ldc);
#else
Tensor c_float;
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, {m, n}, &c_float));
#ifdef ENABLE_MKLDNN_V1
const char ftrans[] = {'N', 'T', 'C'};
dnnl_gemm<bfloat16>(ftrans[index_transa], ftrans[index_transb], m, n, k,
alpha, a, lda, b, ldb, beta,
c_float.flat<float>().data(), ldc);
#else
const char* const ftrans[] = {"N", "T", "C"};

// MKL-DNN only supports the Fortran API and requires column major while
// Tensorflow uses row major so we reverse the order of A and B.
@@ -200,9 +201,8 @@
reinterpret_cast<const mkldnn_bfloat16_t*>(b), &ldb,
reinterpret_cast<const mkldnn_bfloat16_t*>(a), &lda,
&beta, c_float.flat<float>().data(), &ldc);

FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
#endif // ENABLE_MKLDNN_V1
FloatToBFloat16(c_float.flat<float>().data(), c, c_float.NumElements());
}
#endif // ENABLE_INTEL_MKL_BFLOAT16
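
For reference, the revised bfloat16 path above computes the GEMM into a temporary float tensor (c_float) and only then narrows the result to bfloat16 via FloatToBFloat16. A minimal sketch of that accumulate-in-float, convert-at-the-end pattern; the truncating conversion here is an illustrative assumption, not TensorFlow's actual FloatToBFloat16 routine:

// GEMM output lands in a float buffer first; each value is then narrowed to
// bfloat16. Here the narrowing just keeps the top 16 bits of the IEEE float
// (sign, exponent, 7 mantissa bits) -- a simplification for illustration.
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

static uint16_t FloatToBF16Truncate(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

int main() {
  const int m = 2, n = 2;
  std::vector<float> c_float(m * n, 1.5f);  // stand-in for the f32 GEMM result
  std::vector<uint16_t> c_bf16(m * n);
  for (int i = 0; i < m * n; ++i) c_bf16[i] = FloatToBF16Truncate(c_float[i]);
  // 1.5f == 0x3fc00000, so this prints 0x3fc0.
  std::printf("0x%04x\n", static_cast<unsigned>(c_bf16[0]));
  return 0;
}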

1 change: 1 addition & 0 deletions tensorflow/core/kernels/mkl_matmul_op_fused.cc
@@ -243,6 +243,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklFusedMatMulOp<CPUDevice, type>);
TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES);

} // namespace tensorflow

38 changes: 18 additions & 20 deletions tensorflow/core/kernels/mkl_matmul_ops_common.h
@@ -548,16 +548,14 @@ template <typename T>
void dnnl_gemm_batch(const std::vector<bool>& transa,
const std::vector<bool>& transb, const std::vector<int>& m,
const std::vector<int>& n, const std::vector<int>& k,
const std::vector<float>& alpha, const T** a,
const std::vector<int>& lda, const T** b,
const std::vector<int>& ldb,
const std::vector<float>& beta, T** c,
const std::vector<int>& ldc, const int group_count,
const std::vector<float>& alpha, const T* a, const T* b,
const std::vector<float>& beta, T* c,
const int group_count,
const std::vector<int>& group_size) {
// Current BatchMatMul support in Tensorflow is narrower than the one offered
// by MKL and MKL-DNN. Current BatchMatMul support in Tensorflow uses only 1
// group of size equal to batch_size, and all MatMul parameters (m, n, k,
// lda, ldb, ldc, alpha, beta) within that group are same.
// alpha, beta) within that group are same.
DCHECK(group_size.size() == 1);
DCHECK(transa.size() == group_size[0]);
DCHECK(transb.size() == group_size[0]);
@@ -566,9 +564,6 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
DCHECK(m.size() == group_size[0]);
DCHECK(n.size() == group_size[0]);
DCHECK(k.size() == group_size[0]);
DCHECK(lda.size() == group_size[0]);
DCHECK(ldb.size() == group_size[0]);
DCHECK(ldc.size() == group_size[0]);
for (int64_t idx = 0; idx < group_size[0]; idx++)
DCHECK(transa[0] == transa[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++)
@@ -580,21 +575,24 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(m[0] == m[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(n[0] == n[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(k[0] == k[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(lda[0] == lda[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(ldb[0] == ldb[idx]);
for (int64_t idx = 0; idx < group_size[0]; idx++) DCHECK(ldc[0] == ldc[idx]);

using dims = mkldnn::memory::dims;
// Prepare strides based on the transa and transb flags: transposed
// matrices have strides swapped. BatchMatMul in MKL-DNN supports 3D matrices
// so far; that is why strides are 3D also.
dims a_strides = transa[0] ? dims{lda[0], 1, 1} : dims{1, 1, lda[0]};
dims b_strides = transb[0] ? dims{ldb[0], 1, 1} : dims{1, 1, ldb[0]};
dims c_strides = dims{ldc[0], 1, 1};
dims a_sizes = dims{group_size[0], m[0], k[0]};
dims b_sizes = dims{group_size[0], k[0], n[0]};
dims c_sizes = dims{group_size[0], m[0], n[0]};
dims a_strides =
!transa[0] ? dims{m[0] * k[0], k[0], 1} : dims{k[0] * m[0], 1, m[0]};
dims b_strides =
!transb[0] ? dims{k[0] * n[0], n[0], 1} : dims{n[0] * k[0], 1, k[0]};
dims c_strides = dims{m[0] * n[0], n[0], 1};

// Prepare memory descriptors
memory::desc a_md({group_size[0], m[0], k[0]}, MklDnnType<T>(), a_strides);
memory::desc b_md({group_size[0], k[0], n[0]}, MklDnnType<T>(), b_strides);
memory::desc c_md({group_size[0], m[0], n[0]}, MklDnnType<T>(), c_strides);
memory::desc a_md(a_sizes, MklDnnType<T>(), a_strides);
memory::desc b_md(b_sizes, MklDnnType<T>(), b_strides);
memory::desc c_md(c_sizes, MklDnnType<T>(), c_strides);
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha[0] != 1.f) attr.set_output_scales(/* mask */ 0, {alpha[0]});
@@ -610,7 +608,7 @@ void dnnl_gemm_batch(const std::vector<bool>& transa,
template <typename T>
void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
float alpha, const T* a, int64_t lda, const T* b, int64_t ldb,
float beta, T* c, int64_t ldc) {
float beta, float* c, int64_t ldc) {
using dims = mkldnn::memory::dims;
// Prepare strides based on the transa and transb flags: transposed
// matrices have strides swapped
@@ -619,7 +617,7 @@ void dnnl_gemm(char transa, char transb, int64_t m, int64_t n, int64_t k,
// Prepare memory descriptors
memory::desc a_md({m, k}, MklDnnType<T>(), a_strides);
memory::desc b_md({k, n}, MklDnnType<T>(), b_strides);
memory::desc c_md({m, n}, MklDnnType<T>(), {ldc, 1});
memory::desc c_md({m, n}, MklDnnType<float>(), {ldc, 1});
// Create attributes (to handle alpha and beta if necessary)
mkldnn::primitive_attr attr;
if (alpha != 1.f) attr.set_output_scales(/* mask */ 0, {alpha});
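
For reference, with the explicit lda/ldb/ldc arguments removed, dnnl_gemm_batch above derives the 3D strides for dense row-major batches directly from m, k, and n, and a transposed operand simply swaps the two inner strides. A small standalone check of that stride logic for one concrete shape (the shape values and the program are illustrative assumptions):

// Strides for a dense batch of row-major matrices, mirroring the logic in
// dnnl_gemm_batch: non-transposed A is {m*k, k, 1}; transposed operands swap
// the two inner strides. Shape values below are arbitrary.
#include <cstdio>
#include <vector>

int main() {
  const long m = 2, k = 3, n = 5;
  const bool transa = false, transb = true;

  std::vector<long> a_strides =
      !transa ? std::vector<long>{m * k, k, 1} : std::vector<long>{k * m, 1, m};
  std::vector<long> b_strides =
      !transb ? std::vector<long>{k * n, n, 1} : std::vector<long>{n * k, 1, k};
  std::vector<long> c_strides = {m * n, n, 1};

  // With m=2, k=3, n=5 and B transposed this prints: A{6,3,1} B{15,1,3} C{10,5,1}
  std::printf("A{%ld,%ld,%ld} B{%ld,%ld,%ld} C{%ld,%ld,%ld}\n",
              a_strides[0], a_strides[1], a_strides[2],
              b_strides[0], b_strides[1], b_strides[2],
              c_strides[0], c_strides[1], c_strides[2]);
  return 0;
}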
6 changes: 5 additions & 1 deletion tensorflow/core/kernels/mkl_tmp_bf16_ops.cc
@@ -52,7 +52,11 @@ namespace tensorflow {
.Device(DEVICE_CPU) \
.TypeConstraint<bfloat16>("T") \
.TypeConstraint<float>("U"), \
NoOp);
NoOp); \
REGISTER_KERNEL_BUILDER( \
Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp); \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), NoOp);

TF_CALL_bfloat16(REGISTER_CPU);
#undef REGISTER_CPU