From 6b8ff1610a10ef1df5503916ff1b052ba4de0949 Mon Sep 17 00:00:00 2001
From: cfRod
Date: Tue, 5 Oct 2021 15:20:16 +0100
Subject: [PATCH] Update MklMatMulPrimitiveFactory to support Arm Compute
 Library backend

Related to issue #47415 and PR #47775.

Adds support for caching matmul primitives and updates
onednn_acl_primitives.patch to include the matmul primitives.
---
 .../core/kernels/mkl/mkl_batch_matmul_op.cc   |   9 +
 .../core/kernels/mkl/mkl_matmul_ops_common.h  |   7 +-
 .../mkl_dnn/onednn_acl_primitives.patch       | 543 ++++++++++++++++++
 3 files changed, 558 insertions(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
index a4a22f463a8372..79677299d47cda 100644
--- a/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl/mkl_batch_matmul_op.cc
@@ -140,6 +140,15 @@ class BatchMatMulMkl : public OpKernel {
     MklBatchMatMulHelper bmm;
     auto params = bmm.CreateMatMulParams(lhs.shape(), rhs.shape(), out_shape,
                                          adj_x_, adj_y_);
+
+#ifdef DNNL_AARCH64_USE_ACL
+    // ACL does not support reuse of primitives with different data.
+    // For matmul, the previous approach (PR #47775) of keying on Tensor
+    // addresses does not work: addresses are reused with different data.
+    // The counter ensures we still benefit from caching via SetMklMatmul().
+    static int counter = 1;
+    params->aarch64_counter = counter++;
+#endif
     // Create or retrieve matmul primitive from cache.
     MklMatMulPrimitive<Scalar>* matmul_prim =
         MklMatMulPrimitiveFactory<Scalar>::Get(
diff --git a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
index 2c17526a3abf9d..4a4a09f69b16f9 100644
--- a/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
+++ b/tensorflow/core/kernels/mkl/mkl_matmul_ops_common.h
@@ -536,6 +536,9 @@ struct MklMatMulParams {
   memory::dims a_strides;
   memory::dims b_strides;
   memory::dims c_strides;
+#ifdef DNNL_AARCH64_USE_ACL
+  int aarch64_counter;
+#endif
 
   MklMatMulParams(memory::dims a_dims, memory::dims b_dims, memory::dims c_dims,
                   memory::dims a_strides, memory::dims b_strides,
@@ -697,7 +700,9 @@ class MklMatMulPrimitiveFactory : public MklPrimitiveFactory<T> {
     key_creator.AddAsKey(params.b_strides);
     key_creator.AddAsKey(params.c_strides);
     key_creator.AddAsKey(typeid(T).name());
-
+#ifdef DNNL_AARCH64_USE_ACL
+    key_creator.AddAsKey(params.aarch64_counter);
+#endif
     return key_creator.GetKey();
   }
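Reviewer note: the effect of the counter is that on AArch64/ACL builds every
op instance gets its own cache key, so each instance creates a primitive once
and then reuses it, while two instances with identical shapes never share one
(ACL forbids that, since its configured functions capture data pointers). A
minimal sketch of the keying idea; MatMulParams and MakeKey are hypothetical
stand-ins for TensorFlow's MklMatMulParams and FactoryKeyCreator:

    #include <sstream>
    #include <string>

    // Hypothetical stand-ins for MklMatMulParams / FactoryKeyCreator.
    struct MatMulParams {
      int m, k, n;
      int aarch64_counter;  // unique per op instance on ACL builds
    };

    std::string MakeKey(const MatMulParams& p) {
      std::ostringstream key;
      key << "matmul:" << p.m << "x" << p.k << "x" << p.n;
    #ifdef DNNL_AARCH64_USE_ACL
      // Without this, two ops with the same shapes would share one cached
      // primitive even though they feed it different data.
      key << ":" << p.aarch64_counter;
    #endif
      return key.str();
    }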
diff --git a/third_party/mkl_dnn/onednn_acl_primitives.patch b/third_party/mkl_dnn/onednn_acl_primitives.patch
index 73314b8a0afbe4..a46f5634570e9f 100644
--- a/third_party/mkl_dnn/onednn_acl_primitives.patch
+++ b/third_party/mkl_dnn/onednn_acl_primitives.patch
@@ -1870,3 +1870,546 @@ index 755c74550..6b06414b6 100644
          CPU_INSTANCE(gemm_inner_product_bwd_data_t)
          CPU_INSTANCE(gemm_inner_product_bwd_weights_t)
 
+diff --git a/src/cpu/aarch64/matmul/acl_matmul.cpp b/src/cpu/aarch64/matmul/acl_matmul.cpp
+new file mode 100644
+index 000000000..3945fda6f
+--- /dev/null
++++ b/src/cpu/aarch64/matmul/acl_matmul.cpp
+@@ -0,0 +1,87 @@
++/*******************************************************************************
++* Copyright 2021 Arm Ltd. and affiliates
++*
++* Licensed under the Apache License, Version 2.0 (the "License");
++* you may not use this file except in compliance with the License.
++* You may obtain a copy of the License at
++*
++*     http://www.apache.org/licenses/LICENSE-2.0
++*
++* Unless required by applicable law or agreed to in writing, software
++* distributed under the License is distributed on an "AS IS" BASIS,
++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++* See the License for the specific language governing permissions and
++* limitations under the License.
++*******************************************************************************/
++
++#include "cpu/aarch64/matmul/acl_matmul.hpp"
++
++namespace dnnl {
++namespace impl {
++namespace cpu {
++namespace aarch64 {
++namespace matmul {
++
++using namespace data_type;
++
++status_t acl_matmul_t::execute_forward(const exec_ctx_t &ctx) const {
++
++    status_t status = status::success;
++    auto src_base = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
++    auto wei_base = CTX_IN_MEM(const data_t *, DNNL_ARG_WEIGHTS);
++    auto dst_base = CTX_OUT_MEM(data_t *, DNNL_ARG_DST);
++
++    bool is_transA = pd()->amp_.is_transA;
++    bool is_transB = pd()->amp_.is_transB;
++
++    std::lock_guard<std::mutex> _lock {this->mtx};
++    auto *acl_resource = ctx.get_resource_mapper()->get<acl_resource_t>(this);
++    acl_matmul_obj_t &acl_obj = acl_resource->get_acl_obj();
++    // Run transpose kernel
++    if (is_transA && !is_transB) {
++        acl_obj.src_tensor.allocator()->allocate();
++        acl_obj.src_acc_tensor.allocator()->import_memory(
++                const_cast<data_t *>(src_base));
++        acl_obj.transA.run();
++        acl_obj.wei_tensor.allocator()->import_memory(
++                const_cast<data_t *>(wei_base));
++    } else if (is_transB && !is_transA) {
++        acl_obj.wei_tensor.allocator()->allocate();
++        acl_obj.wei_acc_tensor.allocator()->import_memory(
++                const_cast<data_t *>(wei_base));
++        acl_obj.transB.run();
++        acl_obj.src_tensor.allocator()->import_memory(
++                const_cast<data_t *>(src_base));
++    } else if (is_transA && is_transB) {
++        acl_obj.src_tensor.allocator()->allocate();
++        acl_obj.src_acc_tensor.allocator()->import_memory(
++                const_cast<data_t *>(src_base));
++        acl_obj.wei_tensor.allocator()->allocate();
++        acl_obj.wei_acc_tensor.allocator()->import_memory(
++                const_cast<data_t *>(wei_base));
++        acl_obj.transA.run();
++        acl_obj.transB.run();
++    } else {
++        acl_obj.src_tensor.allocator()->import_memory(
++                const_cast<data_t *>(src_base));
++        acl_obj.wei_tensor.allocator()->import_memory(
++                const_cast<data_t *>(wei_base));
++    }
++
++    acl_obj.dst_tensor.allocator()->import_memory(dst_base);
++
++    acl_obj.gemm.run();
++
++    acl_obj.src_tensor.allocator()->free();
++    acl_obj.wei_tensor.allocator()->free();
++    acl_obj.dst_tensor.allocator()->free();
++    if (is_transA) acl_obj.src_acc_tensor.allocator()->free();
++    if (is_transB) acl_obj.wei_acc_tensor.allocator()->free();
++    return status;
++}
++
++} // namespace matmul
++} // namespace aarch64
++} // namespace cpu
++} // namespace impl
++} // namespace dnnl
+\ No newline at end of file
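Note on the execute path above: none of A, B, or C is copied. oneDNN's
buffers are wrapped via import_memory(), the configured functions run against
them, and free() afterwards only detaches the imported pointers (the
allocate() calls are for the intermediate tensors holding transposed
operands). A self-contained sketch of that lifecycle against the public ACL
API; shapes and values are illustrative only:

    #include "arm_compute/runtime/NEON/NEFunctions.h"
    #include "arm_compute/runtime/Tensor.h"

    #include <vector>

    int main() {
        using namespace arm_compute;
        const int M = 2, K = 4, N = 3;
        std::vector<float> a(M * K, 1.f), b(K * N, 1.f), c(M * N, 0.f);

        // ACL TensorShape is (width, height, ...) == (cols, rows, ...).
        Tensor ta, tb, tc;
        ta.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));
        tb.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32));
        tc.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));

        NEGEMM gemm;
        gemm.configure(&ta, &tb, nullptr, &tc, 1.f, 0.f);  // C = A * B

        // Zero-copy: wrap the existing buffers instead of allocating.
        ta.allocator()->import_memory(a.data());
        tb.allocator()->import_memory(b.data());
        tc.allocator()->import_memory(c.data());
        gemm.run();

        // Detach the imported pointers; the vectors still own the memory.
        ta.allocator()->free();
        tb.allocator()->free();
        tc.allocator()->free();
        return 0;  // c now holds the 2x3 product (every entry == 4)
    }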
+diff --git a/src/cpu/aarch64/matmul/acl_matmul.hpp b/src/cpu/aarch64/matmul/acl_matmul.hpp
+new file mode 100644
+index 000000000..6ba17e86d
+--- /dev/null
++++ b/src/cpu/aarch64/matmul/acl_matmul.hpp
+@@ -0,0 +1,154 @@
++/*******************************************************************************
++* Copyright 2021 Arm Ltd. and affiliates
++*
++* Licensed under the Apache License, Version 2.0 (the "License");
++* you may not use this file except in compliance with the License.
++* You may obtain a copy of the License at
++*
++*     http://www.apache.org/licenses/LICENSE-2.0
++*
++* Unless required by applicable law or agreed to in writing, software
++* distributed under the License is distributed on an "AS IS" BASIS,
++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++* See the License for the specific language governing permissions and
++* limitations under the License.
++*******************************************************************************/
++
++#ifndef ACL_MATMUL_HPP
++#define ACL_MATMUL_HPP
++
++#include "cpu/aarch64/matmul/acl_matmul_utils.hpp"
++
++namespace dnnl {
++namespace impl {
++namespace cpu {
++namespace aarch64 {
++namespace matmul {
++
++struct acl_resource_t : public resource_t {
++    acl_resource_t() : acl_obj_(utils::make_unique<acl_matmul_obj_t>()) {}
++
++    status_t configure(const acl_matmul_conf_t &amp) {
++        if (!acl_obj_) return status::out_of_memory;
++        acl_obj_->src_tensor.allocator()->init(amp.src_info);
++        acl_obj_->wei_tensor.allocator()->init(amp.wei_info);
++        acl_obj_->dst_tensor.allocator()->init(amp.dst_info);
++        // Configure transpose kernel for src, wei or both
++        if (amp.is_transA) {
++            acl_obj_->src_acc_tensor.allocator()->init(amp.src_acc_info);
++            acl_obj_->transA.configure(
++                    &acl_obj_->src_acc_tensor, &acl_obj_->src_tensor);
++        }
++        if (amp.is_transB) {
++            acl_obj_->wei_acc_tensor.allocator()->init(amp.wei_acc_info);
++            acl_obj_->transB.configure(
++                    &acl_obj_->wei_acc_tensor, &acl_obj_->wei_tensor);
++        }
++        // Configure GEMM
++        acl_obj_->gemm.configure(&acl_obj_->src_tensor, &acl_obj_->wei_tensor,
++                nullptr, &acl_obj_->dst_tensor, amp.alpha, 0.0f, amp.gemm_info);
++        return status::success;
++    }
++    acl_matmul_obj_t &get_acl_obj() const { return *acl_obj_; }
++
++    DNNL_DISALLOW_COPY_AND_ASSIGN(acl_resource_t);
++
++private:
++    std::unique_ptr<acl_matmul_obj_t> acl_obj_;
++};
++
++struct acl_matmul_t : public primitive_t {
++    struct pd_t : public dnnl::impl::cpu::matmul::cpu_matmul_pd_t {
++
++        pd_t(const matmul_desc_t *adesc, const primitive_attr_t *attr,
++                const cpu_matmul_pd_t *hint_fwd_pd)
++            : cpu_matmul_pd_t(adesc, attr, hint_fwd_pd), amp_() {}
++
++        using cpu_matmul_pd_t::cpu_matmul_pd_t;
++
++        DECLARE_COMMON_PD_T("gemm:acl", acl_matmul_t, USE_GLOBAL_SCRATCHPAD);
++
++        status_t init(engine_t *engine) {
++            using smask_t = primitive_attr_t::skip_mask_t;
++            bool ok = src_md()->data_type == data_type::f32
++                    && weights_md()->data_type == data_type::f32
++                    && desc()->accum_data_type == data_type::f32
++                    && dst_md()->data_type == data_type::f32
++                    && platform::has_data_type_support(data_type::f32)
++                    && attr()->has_default_values(
++                            smask_t::oscale | smask_t::post_ops)
++                    && post_ops_ok() && attr_oscale_ok()
++                    && !has_runtime_dims_or_strides();
++            if (!ok) return status::unimplemented;
++
++            auto conf_status = acl_matmul_utils::init_conf_matmul(amp_, src_md_,
++                    weights_md_, dst_md_, bias_md_, *desc(), *attr());
++
++            if (conf_status != status::success) return status::unimplemented;
++            // Number of threads in Compute Library is set by OMP_NUM_THREADS
++            // dnnl_get_max_threads() == OMP_NUM_THREADS
++            acl_common_utils::acl_thread_bind();
++
++            return status::success;
++        }
++
++        acl_matmul_conf_t amp_;
++
++    protected:
++        bool post_ops_ok() const {
++            using namespace data_type;
++            using namespace alg_kind;
++            auto const &po = attr()->post_ops_;
++            auto is_eltwise
++                    = [&](int idx) { return po.entry_[idx].is_eltwise(); };
++            bool eltwise_only = (po.len() == 1) ? is_eltwise(0) : false;
++            bool eltwise_ok = false;
++            if (eltwise_only) {
++                const auto act_type = po.entry_[0].eltwise.alg;
++                eltwise_ok = acl_matmul_utils::acl_act_ok(act_type);
++            }
++            return eltwise_ok || (po.len() == 0);
++        }
++
++        bool attr_oscale_ok() const {
++            const auto &oscale = attr()->output_scales_;
++            return oscale.mask_ == 0;
++        }
++    };
++
++    acl_matmul_t(const pd_t *apd) : primitive_t(apd) {}
++
++    status_t create_resource(
++            engine_t *engine, resource_mapper_t &mapper) const override {
++        if (mapper.has_resource(this)) return status::success;
++        auto r = utils::make_unique<acl_resource_t>();
++        if (!r) return status::out_of_memory;
++
++        // Configure the resource based on information from primitive descriptor
++        auto st = r->configure(pd()->amp_);
++        if (st == status::success) { mapper.add(this, std::move(r)); }
++
++        return st;
++    }
++
++    typedef typename prec_traits<data_type::f32>::type data_t;
++
++    status_t execute(const exec_ctx_t &ctx) const override {
++        return execute_forward(ctx);
++    }
++
++private:
++    // To guard the const execute_forward(), the mutex must be 'mutable'
++    mutable std::mutex mtx;
++    status_t execute_forward(const exec_ctx_t &ctx) const;
++
++    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
++}; // acl_matmul_t
++
++} // namespace matmul
++} // namespace aarch64
++} // namespace cpu
++} // namespace impl
++} // namespace dnnl
++
++#endif
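For context, pd_t::init() above only admits f32 src/weights/dst, no bias, no
runtime dims or strides, a zero-mask output scale, and at most one eltwise
post-op that acl_act_ok() recognises. A hedged sketch of an attr this
descriptor would accept, using the oneDNN 2.x C++ API (append_eltwise's
leading argument is the post-op scale):

    #include "dnnl.hpp"

    // Builds the one post-op shape that post_ops_ok() accepts: a single
    // fused eltwise, here plain ReLU (alpha = 0, so ACL RELU, not LEAKY_RELU).
    dnnl::primitive_attr make_relu_attr() {
        dnnl::post_ops po;
        po.append_eltwise(1.f, dnnl::algorithm::eltwise_relu, 0.f, 0.f);
        dnnl::primitive_attr attr;
        attr.set_post_ops(po);
        return attr;
    }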
+diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.cpp b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp
+new file mode 100644
+index 000000000..bf35ef83c
+--- /dev/null
++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.cpp
+@@ -0,0 +1,174 @@
++/*******************************************************************************
++* Copyright 2021 Arm Ltd. and affiliates
++*
++* Licensed under the Apache License, Version 2.0 (the "License");
++* you may not use this file except in compliance with the License.
++* You may obtain a copy of the License at
++*
++*     http://www.apache.org/licenses/LICENSE-2.0
++*
++* Unless required by applicable law or agreed to in writing, software
++* distributed under the License is distributed on an "AS IS" BASIS,
++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++* See the License for the specific language governing permissions and
++* limitations under the License.
++*******************************************************************************/
++
++#include "cpu/matmul/matmul_utils.hpp"
++
++#include "cpu/aarch64/matmul/acl_matmul_utils.hpp"
++
++namespace dnnl {
++namespace impl {
++namespace cpu {
++namespace aarch64 {
++namespace matmul {
++
++using namespace dnnl::impl::status;
++using namespace dnnl::impl::utils;
++using namespace dnnl::impl::cpu::matmul;
++using namespace prop_kind;
++using namespace format_tag;
++using namespace dnnl::impl::alg_kind;
++
++namespace acl_matmul_utils {
++
++status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
++        memory_desc_t &wei_md, memory_desc_t &dst_md, memory_desc_t &bias_md,
++        const matmul_desc_t &md, const primitive_attr_t &attr) {
++
++    const memory_desc_wrapper src_d(&src_md);
++    const memory_desc_wrapper wei_d(&wei_md);
++    const memory_desc_wrapper dst_d(&dst_md);
++    const memory_desc_wrapper bia_d(&bias_md);
++
++    matmul_helper_t helper(src_d, wei_d, dst_d);
++    const dim_t M = helper.M();
++    const dim_t N = helper.N();
++    const dim_t K = helper.K();
++    const dim_t batch = helper.batch();
++
++    // ACL does not support bias
++    amp.with_bias = md.bias_desc.format_kind != format_kind::undef;
++    if (amp.with_bias) return status::unimplemented;
++
++    auto src_tag = memory_desc_matches_one_of_tag(
++            src_md, abcd, abdc, abc, acb, ab, ba);
++    auto wei_tag = memory_desc_matches_one_of_tag(
++            wei_md, abcd, abdc, abc, acb, ab, ba);
++    auto dst_tag = memory_desc_matches_one_of_tag(
++            dst_md, abcd, abdc, abc, acb, ab, ba);
++    if (one_of(format_tag::undef, src_tag, wei_tag, dst_tag)) {
++        return status::unimplemented;
++    }
++    amp.is_transA = helper.transA() == 'T';
++    amp.is_transB = helper.transB() == 'T';
++    if (amp.is_transA)
++        amp.src_acc_info = arm_compute::TensorInfo(
++                arm_compute::TensorShape(M, K, 1, batch), 1,
++                arm_compute::DataType::F32);
++    if (amp.is_transB)
++        amp.wei_acc_info
++                = arm_compute::TensorInfo(arm_compute::TensorShape(K, N, batch),
++                        1, arm_compute::DataType::F32);
++
++    amp.src_info
++            = arm_compute::TensorInfo(arm_compute::TensorShape(K, M, 1, batch),
++                    1, arm_compute::DataType::F32);
++    amp.wei_info
++            = arm_compute::TensorInfo(arm_compute::TensorShape(N, K, batch), 1,
++                    arm_compute::DataType::F32);
++    amp.dst_info
++            = arm_compute::TensorInfo(arm_compute::TensorShape(N, M, 1, batch),
++                    1, arm_compute::DataType::F32);
++
++    // Fused eltwise activation (e.g. ReLU)
++    amp.gemm_info.set_activation_info(get_acl_act(attr));
++    // Set alpha (output scaling)
++    amp.alpha = attr.output_scales_.scales_[0];
++    // Validate ACL transpose
++    if (amp.is_transA) {
++        auto acl_transA_st = arm_compute::NETranspose::validate(
++                &amp.src_acc_info, &amp.src_info);
++        if (acl_transA_st.error_code() != arm_compute::ErrorCode::OK) {
++            printf("%s\n", acl_transA_st.error_description().c_str());
++            return status::unimplemented;
++        }
++    }
++    if (amp.is_transB) {
++        auto acl_transB_st = arm_compute::NETranspose::validate(
++                &amp.wei_acc_info, &amp.wei_info);
++        if (acl_transB_st.error_code() != arm_compute::ErrorCode::OK) {
++            printf("%s\n", acl_transB_st.error_description().c_str());
++            return status::unimplemented;
++        }
++    }
++    // Validate ACL GEMM
++    auto acl_st = arm_compute::NEGEMM::validate(&amp.src_info, &amp.wei_info,
++            nullptr, &amp.dst_info, amp.alpha, 0.0f, amp.gemm_info);
++    if (acl_st.error_code() != arm_compute::ErrorCode::OK) {
++        printf("%s\n", acl_st.error_description().c_str());
++        return status::unimplemented;
++    }
++
++    return status::success;
++}
++
++arm_compute::ActivationLayerInfo get_acl_act(const primitive_attr_t &attr) {
++    const auto &post_ops = attr.post_ops_;
++    const int entry_idx = post_ops.find(primitive_kind::eltwise);
++    if (entry_idx == -1) { return arm_compute::ActivationLayerInfo(); }
++
++    const auto eltwise_alg = post_ops.entry_[entry_idx].eltwise.alg;
++    float alpha = post_ops.entry_[entry_idx].eltwise.alpha;
++    float beta = post_ops.entry_[entry_idx].eltwise.beta;
++
++    using acl_act_t = arm_compute::ActivationLayerInfo::ActivationFunction;
++    acl_act_t acl_act_alg;
++    switch (eltwise_alg) {
++        case eltwise_relu:
++            // oneDNN defines RELU: f(x) = (x > 0) ? x : a*x
++            // Compute Library defines LEAKY_RELU: f(x) = (x > 0) ? x : a*x,
++            // whilst Compute Library RELU is defined as: f(x) = max(0,x)
++            if (alpha == 0) {
++                acl_act_alg = acl_act_t::RELU;
++            } else {
++                acl_act_alg = acl_act_t::LEAKY_RELU;
++            }
++            break;
++        case eltwise_tanh:
++            // oneDNN defines TANH activation as: f(x) = tanh(x)
++            // Compute Library defines TANH activation as: f(x) = a*tanh(b*x)
++            // Setting a = b = 1 makes the two equivalent
++            alpha = 1.f;
++            beta = 1.f;
++            acl_act_alg = acl_act_t::TANH;
++            break;
++        case eltwise_elu: acl_act_alg = acl_act_t::ELU; break;
++        case eltwise_square: acl_act_alg = acl_act_t::SQUARE; break;
++        case eltwise_abs: acl_act_alg = acl_act_t::ABS; break;
++        case eltwise_sqrt: acl_act_alg = acl_act_t::SQRT; break;
++        case eltwise_linear: acl_act_alg = acl_act_t::LINEAR; break;
++        case eltwise_bounded_relu: acl_act_alg = acl_act_t::BOUNDED_RELU; break;
++        case eltwise_soft_relu: acl_act_alg = acl_act_t::SOFT_RELU; break;
++        case eltwise_logistic: acl_act_alg = acl_act_t::LOGISTIC; break;
++        default: return arm_compute::ActivationLayerInfo();
++    }
++
++    return arm_compute::ActivationLayerInfo(acl_act_alg, alpha, beta);
++}
++
++bool acl_act_ok(alg_kind_t eltwise_activation) {
++    return utils::one_of(eltwise_activation, eltwise_relu, eltwise_tanh,
++            eltwise_elu, eltwise_square, eltwise_abs, eltwise_sqrt,
++            eltwise_linear, eltwise_bounded_relu, eltwise_soft_relu,
++            eltwise_logistic);
++}
++
++} // namespace acl_matmul_utils
++
++} // namespace matmul
++} // namespace aarch64
++} // namespace cpu
++} // namespace impl
++} // namespace dnnl
+\ No newline at end of file
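The comments in get_acl_act() carry the real subtlety: oneDNN's eltwise_tanh
is f(x) = tanh(x) while ACL's TANH is f(x) = a*tanh(b*x), so the remap forces
a = b = 1; likewise oneDNN's relu with a nonzero alpha is what ACL calls
LEAKY_RELU. A hedged sketch of the resulting ACL descriptors:

    #include "arm_compute/core/Types.h"

    using ActFn = arm_compute::ActivationLayerInfo::ActivationFunction;

    // oneDNN eltwise_tanh: f(x) = tanh(x); ACL TANH: f(x) = a*tanh(b*x).
    const arm_compute::ActivationLayerInfo tanh_act(ActFn::TANH,
            /*a=*/1.f, /*b=*/1.f);

    // oneDNN eltwise_relu with alpha != 0 maps to ACL LEAKY_RELU:
    // f(x) = (x > 0) ? x : a*x.
    const arm_compute::ActivationLayerInfo leaky_act(ActFn::LEAKY_RELU,
            /*a=*/0.01f);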
+diff --git a/src/cpu/aarch64/matmul/acl_matmul_utils.hpp b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp
+new file mode 100644
+index 000000000..1411dc4f4
+--- /dev/null
++++ b/src/cpu/aarch64/matmul/acl_matmul_utils.hpp
+@@ -0,0 +1,70 @@
++/*******************************************************************************
++* Copyright 2021 Arm Ltd. and affiliates
++*
++* Licensed under the Apache License, Version 2.0 (the "License");
++* you may not use this file except in compliance with the License.
++* You may obtain a copy of the License at
++*
++*     http://www.apache.org/licenses/LICENSE-2.0
++*
++* Unless required by applicable law or agreed to in writing, software
++* distributed under the License is distributed on an "AS IS" BASIS,
++* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++* See the License for the specific language governing permissions and
++* limitations under the License.
++*******************************************************************************/
++
++#ifndef CPU_AARCH64_ACL_MATMUL_UTILS_HPP
++#define CPU_AARCH64_ACL_MATMUL_UTILS_HPP
++
++#include "cpu/matmul/cpu_matmul_pd.hpp"
++
++#include "cpu/aarch64/acl_utils.hpp"
++
++namespace dnnl {
++namespace impl {
++namespace cpu {
++namespace aarch64 {
++namespace matmul {
++
++struct acl_matmul_obj_t {
++    arm_compute::NEGEMM gemm;
++    arm_compute::NETranspose transA;
++    arm_compute::NETranspose transB;
++    arm_compute::Tensor src_tensor;
++    arm_compute::Tensor src_acc_tensor;
++    arm_compute::Tensor wei_tensor;
++    arm_compute::Tensor wei_acc_tensor;
++    arm_compute::Tensor dst_tensor;
++};
++
++struct acl_matmul_conf_t {
++    bool with_bias;
++    bool is_transA;
++    bool is_transB;
++    arm_compute::TensorInfo src_info;
++    arm_compute::TensorInfo src_acc_info;
++    arm_compute::TensorInfo wei_info;
++    arm_compute::TensorInfo wei_acc_info;
++    arm_compute::TensorInfo dst_info;
++    arm_compute::GEMMInfo gemm_info;
++    float alpha;
++};
++
++namespace acl_matmul_utils {
++
++status_t init_conf_matmul(acl_matmul_conf_t &amp, memory_desc_t &src_md,
++        memory_desc_t &wei_md, memory_desc_t &dst_md, memory_desc_t &bias_md,
++        const matmul_desc_t &md, const primitive_attr_t &attr);
++
++arm_compute::ActivationLayerInfo get_acl_act(const primitive_attr_t &attr);
++bool acl_act_ok(alg_kind_t eltwise_activation);
++} // namespace acl_matmul_utils
++
++} // namespace matmul
++} // namespace aarch64
++} // namespace cpu
++} // namespace impl
++} // namespace dnnl
++
++#endif // CPU_AARCH64_ACL_MATMUL_UTILS_HPP
+\ No newline at end of file
+diff --git a/src/cpu/matmul/cpu_matmul_list.cpp b/src/cpu/matmul/cpu_matmul_list.cpp
+index cb0b87f8f..8520aef8b 100644
+--- a/src/cpu/matmul/cpu_matmul_list.cpp
++++ b/src/cpu/matmul/cpu_matmul_list.cpp
+@@ -1,5 +1,6 @@
+ /*******************************************************************************
+ * Copyright 2019-2021 Intel Corporation
++* Copyright 2021 Arm Ltd. and affiliates
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+@@ -25,6 +26,11 @@
+ #include "cpu/x64/matmul/brgemm_matmul.hpp"
+ using namespace dnnl::impl::cpu::x64::matmul;
+ using namespace dnnl::impl::cpu::x64;
++#elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL
++#include "cpu/aarch64/matmul/acl_matmul.hpp"
++using namespace dnnl::impl::cpu::aarch64::matmul;
++using namespace dnnl::impl::cpu::aarch64;
++
+ #endif
+ 
+ namespace dnnl {
+@@ -37,6 +43,7 @@ using namespace dnnl::impl::cpu::matmul;
+ 
+ // clang-format off
+ const impl_list_item_t impl_list[] = {
++        CPU_INSTANCE_AARCH64_ACL(acl_matmul_t)
+         CPU_INSTANCE(gemm_f32_matmul_t)
+         CPU_INSTANCE_X64(brgemm_matmul_t)
+         CPU_INSTANCE(gemm_bf16_matmul_t)
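With acl_matmul_t registered ahead of gemm_f32_matmul_t, an ordinary f32
matmul created through the oneDNN API on an AArch64 + ACL build should
dispatch to this implementation; pd.impl_info_str() reports which candidate
won. A hedged end-to-end sketch (oneDNN 2.x C++ API, illustrative dims):

    #include "dnnl.hpp"

    #include <iostream>
    #include <vector>

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);

        const memory::dim M = 8, K = 16, N = 4;
        memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc b_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

        matmul::desc md(a_md, b_md, c_md);
        matmul::primitive_desc pd(md, eng);
        // Expected to print "gemm:acl" on an AArch64 build with ACL enabled.
        std::cout << pd.impl_info_str() << "\n";

        std::vector<float> a(M * K, 1.f), b(K * N, 1.f), c(M * N);
        memory a_mem(a_md, eng, a.data());
        memory b_mem(b_md, eng, b.data());
        memory c_mem(c_md, eng, c.data());

        matmul(pd).execute(strm, {{DNNL_ARG_SRC, a_mem},
                                  {DNNL_ARG_WEIGHTS, b_mem},
                                  {DNNL_ARG_DST, c_mem}});
        strm.wait();
        return 0;
    }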