cpu: aarch64: brgemm: Add support for int8 in brgemm kernel #3414
Conversation
Force-pushed from 9f565eb to 6b6d60e.
This looks generally great, thank you for this contribution.
Do you have a rough idea of the performance? For example, compared to F32, is it ~4x faster?
Also, what's the general idea of this algorithm? SDOT is more complicated than FMLA in that it reduces the elements from Int8 to S32. Do we handle this reduction using some kind of blocking? The other way is operating on the transpose and using FADDV, but I can't see that here. In short, it would be great to have a short overview of how this kernel works.
else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8)
    usdot(v1.s, v_a.b, v_b.b);
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8)
    assert(!"unsupported\n");
Can we not just swap v_a and v_b in this case?
@@ -1254,14 +1270,21 @@ void jit_brgemm_kernel_t::set_A_B_matrices() {
     add(reg_aux_B, reg_aux_B, reg_b_offset);
 }

-void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v2, ZReg v3) {
+void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v_b, ZReg v_a) {
v1, v2 and v3 are unclear, although at least consistent. v1, v_b and v_a feels like the right direction, but is more confusing:
- Why is b first?
- What is v1 in this case? If a and b reference the A and B matrices, then it probably makes sense to give v1 a name with a similar theme, maybe v_acc?
@@ -1470,8 +1500,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
         const auto bd_by_load_bytes
                 = (bd >= bd_e - rows_by_load_bytes
                         || brg.brgattr.wary_A_k_tail_read);
-        broadcast(bcst(), A_offset(bd, rd),
-                have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
+        int should_broadcast = static_cast<int>(
What does this mean? We have a variable called should_broadcast, but then we call broadcast even when it is not true. Are these two different kinds of broadcasting? If so, it would be great to make that clear in the variable/function names.
Description
This PR extends the BRGEMM (Batch-Reduce General Matrix Multiplication) kernel to support additional INT8 data types, enabling broader applicability for low-precision computations, particularly in deep learning workloads.
Supported Data Type Tags
The following source:weight:destination (src:wei:dst) combinations are now supported:
Checklist
General
- Do all unit and benchdnn tests (make test and make test_benchdnn_*) pass locally for each commit? Output is the same before and after the code changes.
Command used:
./benchdnn --brgemm --batch=inputs/brgemm/test_brgemm_all
Before
After
Yes