-
Notifications
You must be signed in to change notification settings - Fork 1k
cpu: aarch64: brgemm: Add support for int8 in brgemm kernel #3414
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
/******************************************************************************* | ||
* Copyright 2021-2023 Intel Corporation | ||
* Copyright 2024 FUJITSU LIMITED | ||
* Copyright 2024-2025 FUJITSU LIMITED | ||
* Copyright 2024 Arm Ltd. and affiliates | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
|
@@ -509,12 +509,20 @@ void jit_brgemm_kernel_t::cvt2ps(data_type_t type_in, const ZReg zmm_in, | |
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s); | ||
break; | ||
case data_type::bf16: assert(!"unsupported data type\n"); break; | ||
case data_type::s8: assert(!"unsupported data type\n"); break; | ||
case data_type::u8: assert(!"unsupported data type\n"); break; | ||
case data_type::s8: | ||
LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1); | ||
sxtb(z_tmp_1().d, mask / T_m, z_tmp_1().d); | ||
if (store) // Merging | ||
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s); | ||
break; | ||
case data_type::u8: | ||
LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1); | ||
uxtb(z_tmp_1().s, mask / T_m, z_tmp_1().s); | ||
if (store) // Merging | ||
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s); | ||
break; | ||
default: assert(!"unsupported data type"); | ||
} | ||
if (!one_of(type_in, data_type::f32, data_type::bf16)) | ||
assert(!"unsupported data type\n"); | ||
} | ||
|
||
void jit_brgemm_kernel_t::advance_ldb_post_op_regs() { | ||
|
@@ -961,7 +969,9 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( | |
} | ||
for (int bd = 0; bd < bd_block; bd++) { | ||
auto zmm = accm(ld_block2, bd, ld); | ||
if (dq2ps_required) { scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s); } | ||
if (dq2ps_required && !brg.with_scales) { | ||
scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s); | ||
} | ||
if (brg.with_bias) { fadd(zmm.s, zmm.s, zmm_bias.s); } | ||
} | ||
} | ||
|
@@ -989,7 +999,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops( | |
auto vmm_zp_c = z_tmp_1(); | ||
if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) { | ||
add_imm(X_DEFAULT_ADDR, reg_aux_zp_c_values, 0, X_TMP_0); | ||
ldr(z_tmp_2(), ptr(X_DEFAULT_ADDR)); | ||
ld1rw(z_tmp_2().s, P_ALL_ONE, ptr(X_DEFAULT_ADDR)); | ||
scvtf(vmm_zp_c.s, k_mask / T_m, z_tmp_2().s); | ||
} | ||
for (int ld = 0; ld < ld_block2; ld++) { | ||
|
@@ -1048,7 +1058,7 @@ void jit_brgemm_kernel_t::apply_compensation( | |
if (!brg.req_cal_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) { | ||
auto vmm_zp_a_val = z_tmp_2(); | ||
add_imm(X_DEFAULT_ADDR, X_SP, reg_zp_a_val_offs_, X_TMP_0); | ||
ldr(reg_zp_a_val, ptr(X_DEFAULT_ADDR)); | ||
add_imm(reg_zp_a_val, X_SP, reg_zp_a_val_offs_, X_TMP_0); | ||
ldr(W_TMP_0, ptr(reg_zp_a_val)); | ||
dup(vmm_zp_a_val.s, W_TMP_0); | ||
|
||
|
@@ -1096,13 +1106,15 @@ void jit_brgemm_kernel_t::apply_compensation( | |
const bool is_tail = is_ld_tail && ld + 1 == ld_block2; | ||
if (IMPLICATION(is_tail, is_superset(brg.isa_impl, sve_512))) { | ||
const auto mask = is_tail ? k_mask : P_ALL_ONE; | ||
ld1w(vmm_comp.s, mask / T_z, | ||
ptr(reg_aux_compensation, comp_offset)); | ||
add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset, | ||
X_TMP_1); | ||
ld1w(vmm_comp.s, mask / T_z, ptr(X_DEFAULT_ADDR)); | ||
} else { | ||
not_(P_TMP.b, P_ALL_ONE, P_NOT_256.b); | ||
cmplt(P_TMP.s, P_TMP / T_z, z_tail_mask().s, 0); | ||
ld1w(vmm_comp.s, P_TMP / T_z, | ||
ptr(reg_aux_compensation, comp_offset)); | ||
add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset, | ||
X_TMP_1); | ||
ld1w(vmm_comp.s, P_TMP / T_z, ptr(X_DEFAULT_ADDR)); | ||
} | ||
|
||
for (int bd = 0; bd < bd_block; bd++) { | ||
|
@@ -1154,7 +1166,11 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, bool is_bdb_tail, | |
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block; | ||
|
||
if (brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points)) { | ||
assert(!"unsupported\n"); | ||
Label label_store_without_comp; | ||
cmp_imm(reg_do_comp, 0, X_TMP_0); | ||
b(EQ, label_store_without_comp); | ||
apply_compensation(bd_block, ld_block2, is_ld_tail); | ||
L_aligned(label_store_without_comp); | ||
} | ||
|
||
if (need_to_apply_alpha_beta) | ||
|
@@ -1254,14 +1270,21 @@ void jit_brgemm_kernel_t::set_A_B_matrices() { | |
add(reg_aux_B, reg_aux_B, reg_b_offset); | ||
} | ||
|
||
void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v2, ZReg v3) { | ||
void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v_b, ZReg v_a) { | ||
if (brg.is_f32) { | ||
fmla(v1.s, P_ALL_ONE / T_m, v2.s, v3.s); | ||
fmla(v1.s, P_ALL_ONE / T_m, v_b.s, v_a.s); | ||
} else if (brg.is_bf16) | ||
assert(!"unsupported\n"); | ||
else if (brg.is_int8) | ||
assert(!"unsupported\n"); | ||
else | ||
else if (brg.is_int8 && isa_has_s8s8(brg.isa_impl)) { | ||
if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8) | ||
udot(v1.s, v_b.b, v_a.b); | ||
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8) | ||
sdot(v1.s, v_a.b, v_b.b); | ||
else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8) | ||
usdot(v1.s, v_a.b, v_b.b); | ||
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8) | ||
assert(!"unsupported\n"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we not just swap |
||
} else | ||
assert(!"unsupported\n"); | ||
} | ||
|
||
|
@@ -1294,7 +1317,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, | |
if (brg.zp_type_a != brgemm_broadcast_t::none) { | ||
eor(vmm_tmp.d, vmm_tmp.d, vmm_tmp.d); | ||
dot_product(vmm_tmp, vmm_load, z_one_bytes()); | ||
mul(vmm_tmp.s, P_ALL_ONE / T_m, z_inp_shift().s); | ||
mul(vmm_tmp.s, P_ALL_ONE / T_m, z_zp_a_shift().s); | ||
|
||
for (int bd = bd_b; bd < bd_e; bd++) { | ||
auto vmm = accm(ld_block2, bd, ld); | ||
|
@@ -1321,7 +1344,8 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b, | |
for (int ld = 0; ld < ld_block2; ++ld) { | ||
const bool is_tail = is_ld_tail && ld + 1 == ld_block2; | ||
const auto mask = is_tail ? ld_tail_mask : P_ALL_ONE; | ||
ld1w(load().s, mask / T_z, ptr(reg_aux_B, B_offset(ld, rd))); | ||
add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0); | ||
ld1w(load().s, mask / T_z, ptr(X_DEFAULT_ADDR)); | ||
|
||
if (brg.req_cal_comp_pads) { | ||
compensation_padding(load(), bcst(), ld, bd_b, bd_e); | ||
|
@@ -1348,8 +1372,11 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, | |
|
||
int rd_loop = 0, rd_tail_size = 0; | ||
if (is_rd_tail) { | ||
rd_tail_size = brg.rdb_tail % brg.rd_step; | ||
if (brg.is_bf16 || brg.is_int8) { | ||
assert(!"unsupported\n"); | ||
rd_loop = (rd_tail_size != 0) | ||
? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step | ||
: brg.rdb_tail; | ||
} else | ||
rd_loop = brg.rdb_tail; | ||
} else | ||
|
@@ -1360,9 +1387,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, | |
if (is_tail) { | ||
eor(z1.d, z1.d, z1.d); | ||
auto xmm_tmp = z_tmp_1(); | ||
add_imm(X_DEFAULT_ADDR, reg_aux_A, rd_tail_size * brg.typesize_A, | ||
add_imm(X_DEFAULT_ADDR, reg_aux_A, offset * brg.typesize_A, | ||
X_TMP_0); | ||
set_preg(P_TMP.b, offset); | ||
set_preg(P_TMP.b, rd_tail_size, X_TMP_0, X_TMP_1); | ||
ld1b(xmm_tmp.b, P_TMP / T_z, ptr(X_DEFAULT_ADDR)); | ||
dup(z1.s, xmm_tmp.s[0]); | ||
} else { | ||
|
@@ -1377,7 +1404,8 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, | |
} else if (dt == data_type::bf16) { | ||
assert(!"unsupported\n"); | ||
} else if (one_of(dt, data_type::s8, data_type::u8)) { | ||
assert(!"unsupported\n"); | ||
add_imm(X_DEFAULT_ADDR, reg_aux_A, offset, X_TMP_0); | ||
ld1rw(z1.s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); | ||
} else if (dt == data_type::f16) { | ||
assert(!"unsupported\n"); | ||
} | ||
|
@@ -1389,7 +1417,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, | |
const bool comp_vpad = vpad != 0 | ||
&& (brg.req_s8s8_compensation | ||
|| brg.zp_type_a != brgemm_broadcast_t::none); | ||
if (brg.req_cal_comp_pads || comp_vpad) assert(!"unsupported\n"); | ||
if (brg.req_cal_comp_pads || comp_vpad) | ||
compute_int8_compensation( | ||
rd_loop, bd_b, bd_e, bd_block, ld_block2, is_ld_tail, vpad); | ||
|
||
bool maybe_load_bytes | ||
= (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read) | ||
|
@@ -1407,17 +1437,17 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, | |
have_to_load_bytes && bd_by_load_bytes, brg.dt_a); | ||
} | ||
for (int ld = 0; ld < ld_block2; ld++) { | ||
const auto addr = ptr(reg_aux_B, B_offset(ld, rd)); | ||
const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE; | ||
add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0); | ||
if (brg.dt_b == data_type::f16) { | ||
assert(!"unsupported\n"); | ||
} else if (brg.dt_b == data_type::bf16 | ||
&& brg.isa_impl == sve_256) { | ||
assert(!"unsupported\n"); | ||
} else if (is_ld_tail) { | ||
ld1w(load().s, ld_tail_mask / T_z, addr); | ||
ld1w(load().s, ld_tail_mask / T_z, ptr(X_DEFAULT_ADDR)); | ||
} else { | ||
ld1w(load().s, P_ALL_ONE / T_z, addr); | ||
ld1w(load().s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR)); | ||
} | ||
for (int bd = bd_b; bd < bd_e; bd++) { | ||
auto vmm = accm(ld_block2, bd, ld); | ||
|
@@ -1470,8 +1500,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2, | |
const auto bd_by_load_bytes | ||
= (bd >= bd_e - rows_by_load_bytes | ||
|| brg.brgattr.wary_A_k_tail_read); | ||
broadcast(bcst(), A_offset(bd, rd), | ||
have_to_load_bytes && bd_by_load_bytes, brg.dt_a); | ||
int should_broadcast = static_cast<int>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does this mean? We have a variable called |
||
have_to_load_bytes && bd_by_load_bytes); | ||
broadcast(bcst(), A_offset(bd, rd), should_broadcast, | ||
brg.dt_a); | ||
} | ||
//The current implementaion of prefetch is not giving any gain in performance but is rather introducing some latency. Therefore it is removed util a new useful implementation is deviced. | ||
for (int ld = 0; ld < ld_block2; ld++) { | ||
|
@@ -1561,7 +1593,12 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, | |
|
||
if (brg.req_s8s8_compensation) { assert(!"unsupported\n"); } | ||
if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) { | ||
assert(!"unsupported\n"); | ||
str(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_)); | ||
const auto reg32_scratch = WReg(reg_zp_a_input_shift.getIdx()); | ||
mov(z_one_bytes().b, 1); | ||
ldr(reg32_scratch, ptr(X_SP, reg_zp_a_val_offs_)); | ||
dup(z_zp_a_shift().s, reg32_scratch); | ||
ldr(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_)); | ||
} | ||
|
||
if (brg.brgattr.max_bs > 1) { mov(reg_BS_loop, reg_BS); } | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
v1
,v2
andv3
is unclear, although it is consistent.v1
,v_b
andv_a
feels like the right direction, but is more confusing.b
first?v1
in this case, ifa
andb
reference the A and B matrices, then it probably makes sense to givev1
a name with a similar theme, maybev_acc
?