Skip to content

Commit

Permalink
Update MklMatMulPrimitiveFactory to support Arm Compute Library backend
Browse files Browse the repository at this point in the history
Related to issue #47415 and PR #47775. Adding support for caching matmul primitives.
Updates onednn_acl_primitives.patch to include matmul primitives.
  • Loading branch information
cfRod committed Oct 7, 2021
1 parent 232839b commit 6b8ff16
Show file tree
Hide file tree
Showing 3 changed files with 558 additions and 1 deletion.
9 changes: 9 additions & 0 deletions tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
Expand Up @@ -140,6 +140,15 @@ class BatchMatMulMkl : public OpKernel {
MklBatchMatMulHelper bmm;
auto params = bmm.CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape,
adj_x_, adj_y_);

#ifdef DNNL_AARCH64_USE_ACL
// ACL does not support reuse of primitives with different data.
// For matmul, the previous approach (PR #47775) of using Tensor addresses
// does not work, as the addresses are re-used in matmul with different data.
// The counter ensures we still benefit from caching via SetMklMatmul().
static int counter = 1;
params->aarch64_counter = counter++;
#endif
// Create or retrieve matmul primitive from cache.
MklMatMulPrimitive<Scalar>* matmul_prim =
MklMatMulPrimitiveFactory<Scalar>::Get(
Expand Down
7 changes: 6 additions & 1 deletion tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
Expand Up @@ -536,6 +536,9 @@ struct MklMatMulParams {
memory::dims a_strides;
memory::dims b_strides;
memory::dims c_strides;
#ifdef DNNL_AARCH64_USE_ACL
int aarch64_counter;
#endif

MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
memory::dims a_strides, memory::dims b_strides,
Expand Down Expand Up @@ -697,7 +700,9 @@ class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(params.b_strides);
key_creator.AddAsKey(params.c_strides);
key_creator.AddAsKey(typeid(T).name());

#ifdef DNNL_AARCH64_USE_ACL
key_creator.AddAsKey(params.aarch64_counter);
#endif
return key_creator.GetKey();
}

Expand Down

0 comments on commit 6b8ff16

Please sign in to comment.