cpu: aarch64: brgemm: Add support for int8 in brgemm kernel #3414


Open · wants to merge 1 commit into main
Conversation

kasturedeeksha
Contributor

Description

This PR extends the BRGEMM (Batch-Reduce General Matrix Multiplication) kernel to support additional INT8 data types, enabling broader applicability for low-precision computations, particularly in deep learning workloads.

Supported Data Type Tags
The following source:weight:destination (src:wei:dst) combinations are now supported:

  • s8:s8:f32
  • u8:u8:f32
  • u8:s8:f32

Checklist

General

  • Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit?
  1. make test output
98% tests passed, 4 tests failed out of 224
 
Total Test time (real) = 1058.22 sec
 
The following tests FAILED:
        172 - test_graph_unit_dnnl_large_partition_cpu (Failed)
        195 - test_benchdnn_modeC_binary_ci_cpu (Failed)
        196 - test_benchdnn_modeC_binary_different_dt_ci_cpu (Failed)
        204 - test_benchdnn_modeC_graph_ci_cpu (Failed)

The output is the same before and after the code changes.

  2. brgemm_test_all output
    Command used:

./benchdnn --brgemm --batch=inputs/brgemm/test_brgemm_all

Before

tests:660480 passed:18496 skipped:641984 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 24.50s; create_pd: 0.00s (0%); create_prim: 0.00s (0%); fill: 7.30s (30%); execute: 0.00s (0%); compute_ref: 4.30s (18%); compare: 5.34s (22%);

After

tests:660480 passed:20480 skipped:640000 mistrusted:0 unimplemented:0 invalid_arguments:0 failed:0 listed:0
total: 23.11s; create_pd: 0.00s (0%); create_prim: 0.00s (0%); fill: 6.83s (30%); execute: 0.00s (0%); compute_ref: 3.82s (17%); compare: 4.40s (19%);
  • Have you formatted the code using clang-format?
    Yes

@kasturedeeksha kasturedeeksha requested a review from a team as a code owner June 11, 2025 08:50
@github-actions github-actions bot added the platform:cpu-aarch64 Codeowner: @oneapi-src/onednn-cpu-aarch64 label Jun 11, 2025
Contributor

@jondea jondea left a comment


This looks generally great, thank you for this contribution.

Do you have a rough idea of the performance? For example, compared to F32, is it ~4x faster?

Also, what's the general idea of this algorithm? SDOT is more complicated than FMLA in that it reduces the elements from int8 to s32. Do we handle this reduction using some kind of blocking? The other way would be to operate on the transpose and use FADDV, but I can't see that here. In short, it would be great to have a short overview of how this kernel works.

else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8)
    usdot(v1.s, v_a.b, v_b.b);
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8)
    assert(!"unsupported\n");
Contributor


Can we not just swap v_a and v_b in this case?

@@ -1254,14 +1270,21 @@ void jit_brgemm_kernel_t::set_A_B_matrices() {
add(reg_aux_B, reg_aux_B, reg_b_offset);
}

-void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v2, ZReg v3) {
+void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v_b, ZReg v_a) {
Contributor


v1, v2 and v3 are unclear, although consistent. v1, v_b and v_a feels like the right direction, but is more confusing.

  • Why is b first?
  • What is v1 in this case? If a and b reference the A and B matrices, then it probably makes sense to give v1 a name with a similar theme, maybe v_acc?

@@ -1470,8 +1500,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
const auto bd_by_load_bytes
        = (bd >= bd_e - rows_by_load_bytes
                || brg.brgattr.wary_A_k_tail_read);
broadcast(bcst(), A_offset(bd, rd),
        have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
int should_broadcast = static_cast<int>(
Contributor


What does this mean? We have a variable called should_broadcast, but then we call broadcast even if it is not true. Are they two different kinds of broadcasting? If so, it would be great to make that clear in the variable/function names.
