Skip to content

Commit

Permalink
aarch64: cherrypick onednn pr#1768 to improve torch.compile perf (#1716)
Browse files Browse the repository at this point in the history
this improves the bert base torch.compile perf by 5.8x on
AWS c7g instance.
  • Loading branch information
snadampal committed Mar 12, 2024
1 parent fa3dbb0 commit c084122
Showing 1 changed file with 106 additions and 0 deletions.
@@ -0,0 +1,106 @@
From b8595f6bcbc4f8e38ea9d1c42a65f82d72a07598 Mon Sep 17 00:00:00 2001
From: Sunita Nadampalli <nadampal@amazon.com>
Date: Wed, 6 Mar 2024 00:37:30 +0000
Subject: [PATCH] onednn: pr1768: aarch64: add acl sbgemm inner product
primitive

---
src/cpu/aarch64/acl_inner_product.hpp | 33 +++++++++++++++++++++++----
src/cpu/cpu_inner_product_list.cpp | 9 ++++++++
2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
index a2be164f09..762d5d4896 100644
--- a/src/cpu/aarch64/acl_inner_product.hpp
+++ b/src/cpu/aarch64/acl_inner_product.hpp
@@ -93,20 +93,33 @@ struct acl_inner_product_fwd_t : public primitive_t {

status_t init(engine_t *engine) {
using namespace data_type;
+ const format_kind_t weights_format_kind_received
+ = weights_md_.format_kind;
+
const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f16);
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
+
+ const bool is_fp32_bf16_ok
+ = expect_data_types(f32, bf16, f32, f32, undef)
+ && attr()->has_default_values(
+ primitive_attr_t::skip_mask_t::post_ops, f32);
+
+ const bool is_weights_md_format_ok
+ = utils::one_of(weights_format_kind_received,
+ format_kind::any, format_kind::blocked);
+
const bool ok = is_fwd() && !has_zero_dim_memory()
- && utils::one_of(true, is_fp16_ok, is_fp32_ok)
- && weights_md_.format_kind == format_kind::any
- && set_default_params() == status::success;
+ && utils::one_of(true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
+ && is_weights_md_format_ok
+ && set_default_params(true) == status::success;

if (!ok) return status::unimplemented;

- CHECK(init_conf_ip(engine));
+ CHECK(init_conf_ip(engine, weights_format_kind_received));

return status::success;
}
@@ -115,7 +128,8 @@ struct acl_inner_product_fwd_t : public primitive_t {

acl_post_ops_t post_ops;

- status_t init_conf_ip(engine_t *engine) {
+ status_t init_conf_ip(
+ engine_t *engine, format_kind_t weights_format_kind_received) {

ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims,
"source and weights dimensions must match");
@@ -257,10 +271,19 @@ struct acl_inner_product_fwd_t : public primitive_t {
return status::unimplemented;
}

+ const memory_desc_t weights_md_received = weights_md_;
acl_utils::reorder_to_weight_format(aip.wei_tensor_info,
weights_md_, expected_weight_format, inner_dim, o_dim,
remaining_dims, {});

+ ACL_CHECK_SUPPORT(
+ (weights_format_kind_received == format_kind::blocked)
+ && !(dnnl_memory_desc_equal(
+ &weights_md_received, &weights_md_)),
+ "specific blocked format not supported by ACL, use "
+ "format_kind_t::any to find a supported blocked format for "
+ "your platform");
+
// clang-format off

// Validate fully connected layer manually to check for return status
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index fdd7b17769..1f59547304 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE(ref_inner_product_fwd_t)
nullptr,
}},
+ /* With graph compilation, we are able to reorder and pre-pack the weights during the model load
+ * and compilation phase itself so that redundant and on-the-fly reorders can be avoided.
+ * This primitive definition is to support gemm fastmath mode for the compile scenario where src is
+ * in fp32 and weights are in bf16
+ */
+ {{forward, f32, bf16, f32}, {
+ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
+ nullptr,
+ }},
{{backward_data, f32, f32, f32}, REG_BWD_PK({
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
--
2.34.1

0 comments on commit c084122

Please sign in to comment.