aarch64: cherrypick onednn pr#1768 to improve torch.compile perf #1716

Merged
merged 1 commit into from Mar 12, 2024
@@ -0,0 +1,106 @@
From b8595f6bcbc4f8e38ea9d1c42a65f82d72a07598 Mon Sep 17 00:00:00 2001
From: Sunita Nadampalli <nadampal@amazon.com>
Date: Wed, 6 Mar 2024 00:37:30 +0000
Subject: [PATCH] onednn: pr1768: aarch64: add acl sbgemm inner product
primitive

---
src/cpu/aarch64/acl_inner_product.hpp | 33 +++++++++++++++++++++++----
src/cpu/cpu_inner_product_list.cpp | 9 ++++++++
2 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp
index a2be164f09..762d5d4896 100644
--- a/src/cpu/aarch64/acl_inner_product.hpp
+++ b/src/cpu/aarch64/acl_inner_product.hpp
@@ -93,20 +93,33 @@ struct acl_inner_product_fwd_t : public primitive_t {

status_t init(engine_t *engine) {
using namespace data_type;
+ const format_kind_t weights_format_kind_received
+ = weights_md_.format_kind;
+
const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f16);
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef)
&& attr()->has_default_values(
primitive_attr_t::skip_mask_t::post_ops, f32);
+
+ const bool is_fp32_bf16_ok
+ = expect_data_types(f32, bf16, f32, f32, undef)
+ && attr()->has_default_values(
+ primitive_attr_t::skip_mask_t::post_ops, f32);
+
+ const bool is_weights_md_format_ok
+ = utils::one_of(weights_format_kind_received,
+ format_kind::any, format_kind::blocked);
+
const bool ok = is_fwd() && !has_zero_dim_memory()
- && utils::one_of(true, is_fp16_ok, is_fp32_ok)
- && weights_md_.format_kind == format_kind::any
- && set_default_params() == status::success;
+ && utils::one_of(true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok)
+ && is_weights_md_format_ok
+ && set_default_params(true) == status::success;

if (!ok) return status::unimplemented;

- CHECK(init_conf_ip(engine));
+ CHECK(init_conf_ip(engine, weights_format_kind_received));

return status::success;
}
@@ -115,7 +128,8 @@ struct acl_inner_product_fwd_t : public primitive_t {

acl_post_ops_t post_ops;

- status_t init_conf_ip(engine_t *engine) {
+ status_t init_conf_ip(
+ engine_t *engine, format_kind_t weights_format_kind_received) {

ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims,
"source and weights dimensions must match");
@@ -257,10 +271,19 @@ struct acl_inner_product_fwd_t : public primitive_t {
return status::unimplemented;
}

+ const memory_desc_t weights_md_received = weights_md_;
acl_utils::reorder_to_weight_format(aip.wei_tensor_info,
weights_md_, expected_weight_format, inner_dim, o_dim,
remaining_dims, {});

+ ACL_CHECK_SUPPORT(
+ (weights_format_kind_received == format_kind::blocked)
+ && !(dnnl_memory_desc_equal(
+ &weights_md_received, &weights_md_)),
+ "specific blocked format not supported by ACL, use "
+ "format_kind_t::any to find a supported blocked format for "
+ "your platform");
+
// clang-format off

// Validate fully connected layer manually to check for return status
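Reader's note (not part of the patch): a minimal sketch of the descriptor combination the relaxed checks above now accept, written against oneDNN's 3.x C++ API (dnnl.hpp). It assumes a oneDNN build with ACL enabled and this patch applied; shapes are purely illustrative. Source and destination stay f32, weights are bf16, and the weights descriptor uses format_tag::any so that reorder_to_weight_format() can pick ACL's blocked layout; a caller-chosen blocked layout that differs from ACL's expectation is rejected by the new ACL_CHECK_SUPPORT above.

    // Sketch only: assumes an aarch64 oneDNN 3.x build with ACL and this patch.
    #include <dnnl.hpp>

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::cpu, 0);

        const memory::dim MB = 8, IC = 64, OC = 32; // illustrative shapes

        // f32 src/dst with bf16 weights: the new is_fp32_bf16_ok combination.
        memory::desc src_md({MB, IC}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc dst_md({MB, OC}, memory::data_type::f32, memory::format_tag::ab);

        // format_tag::any lets the implementation choose its blocked weight layout.
        memory::desc wei_md({OC, IC}, memory::data_type::bf16, memory::format_tag::any);

        // Throws dnnl::error if the build does not support this combination.
        inner_product_forward::primitive_desc ip_pd(eng,
                prop_kind::forward_inference, src_md, wei_md, dst_md);

        // ip_pd.weights_desc() now describes the implementation-chosen bf16 layout.
        return 0;
    }
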
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp
index fdd7b17769..1f59547304 100644
--- a/src/cpu/cpu_inner_product_list.cpp
+++ b/src/cpu/cpu_inner_product_list.cpp
@@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE(ref_inner_product_fwd_t)
nullptr,
}},
+ /* With graph compilation, we are able to reorder and pre-pack the weights during the model load
+ * and compilation phase itself so that redundant and on-the-fly reorders can be avoided.
+ * This primitive definition is to support gemm fastmath mode for the compile scenario where src is
+ * in fp32 and weights are in bf16
+ */
+ {{forward, f32, bf16, f32}, {
+ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t)
+ nullptr,
+ }},
{{backward_data, f32, f32, f32}, REG_BWD_PK({
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>)
--
2.34.1
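
Reader's note (not part of the patch): the comment added to cpu_inner_product_list.cpp describes pre-packing weights once at model load so that no on-the-fly reorder is needed at execution time. Below is a minimal sketch of that flow, again assuming a patched, ACL-enabled oneDNN 3.x build; shapes and values are made up. The original f32 weights are reordered (with data-type conversion) into the bf16 blocked weights_desc() chosen for the {forward, f32, bf16, f32} key registered above, then reused for every forward call. Setting ONEDNN_VERBOSE=1, or querying primitive_desc::impl_info_str() as done here, can confirm which implementation was selected.

    // Sketch only: pre-pack-once flow for the f32-src / bf16-weights fastmath path.
    #include <dnnl.hpp>
    #include <iostream>
    #include <vector>

    int main() {
        using namespace dnnl;
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);

        const memory::dim MB = 8, IC = 64, OC = 32; // illustrative shapes
        memory::desc src_md({MB, IC}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc dst_md({MB, OC}, memory::data_type::f32, memory::format_tag::ab);
        memory::desc wei_any({OC, IC}, memory::data_type::bf16, memory::format_tag::any);

        inner_product_forward::primitive_desc ip_pd(eng,
                prop_kind::forward_inference, src_md, wei_any, dst_md);
        std::cout << "selected implementation: " << ip_pd.impl_info_str() << "\n";

        // Model-load time: reorder the plain f32 OC x IC weights into the bf16
        // blocked layout the primitive expects, exactly once.
        std::vector<float> wei_f32(OC * IC, 0.01f);
        memory wei_plain({{OC, IC}, memory::data_type::f32, memory::format_tag::ab},
                eng, wei_f32.data());
        memory wei_packed(ip_pd.weights_desc(), eng);
        reorder(wei_plain, wei_packed).execute(strm, wei_plain, wei_packed);
        strm.wait();

        // Execution time: reuse wei_packed for every forward call; no reorder needed.
        memory src_m(src_md, eng), dst_m(dst_md, eng);
        inner_product_forward(ip_pd).execute(strm,
                {{DNNL_ARG_SRC, src_m}, {DNNL_ARG_WEIGHTS, wei_packed},
                 {DNNL_ARG_DST, dst_m}});
        strm.wait();
        return 0;
    }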