cpu: aarch64: brgemm: Add support for int8 in brgemm kernel #3414

Open. Wants to merge 1 commit into base: main.
15 changes: 7 additions & 8 deletions src/cpu/aarch64/brgemm/brgemm.cpp
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2025 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
* Copyright 2023-2025 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -229,12 +229,11 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
brg->LDD = LDD;
const auto dt_d = dst_md->data_type;

if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
&& (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
data_type::f32))
&& (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8,
data_type::s32, data_type::f32, data_type::bf16)))
return status::unimplemented;
if (brg->is_int8) {
if ((brg->dt_a == data_type::s8 && brg->dt_b == data_type::u8)
|| (dt_bias != data_type::f32) || (dt_d != data_type::f32))
return status::unimplemented;
}
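In effect (reading the new check; the combinations below are illustrative, not from the PR):

    // Accepted: dt_a = u8, dt_b = s8, dt_bias = f32, dt_d = f32
    // Rejected: dt_a = s8 with dt_b = u8, or any non-f32 bias or destination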
if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
&& (!one_of(dt_d, data_type::bf16, data_type::f32))
&& (!one_of(dt_bias, data_type::undef, data_type::bf16,
@@ -248,7 +247,7 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
brg->dt_d = dt_d;
brg->typesize_D = types::data_type_size(brg->dt_d);

if (brg->is_int8 || (brg->dt_d == bf16)) return status::unimplemented;
if (brg->dt_d == bf16) return status::unimplemented;

if (!brg->attr) return status::success;

9 changes: 4 additions & 5 deletions src/cpu/aarch64/brgemm/brgemm_utils.cpp
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2022-2023 Intel Corporation
* Copyright 2023-2024 FUJITSU LIMITED
* Copyright 2023-2025 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -325,7 +325,8 @@ status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
brg->has_int8_vnni = true;

set_brg_vmm(brg); // TODO: Investigate if it is really needed here.
brg->req_s8s8_compensation = brg->is_int8 && brg->dt_a == data_type::s8;
brg->req_s8s8_compensation = (brg->is_int8 && (brg->dt_a == data_type::s8)
&& !isa_has_s8s8(brg->isa_impl));

brg->LDA = (brg->is_row_major()) ? static_cast<int>(LDA)
: static_cast<int>(LDB);
@@ -344,9 +345,7 @@ status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
brg->bdb2 = 0;
brg->bdb2_tail = 0;

const bool is_b_in_vnni_format = false;
brg->ld_step
= is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1;
brg->ld_step = data_type_vnni_granularity(brg->dt_b);
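For reference, a note on what this changes (per oneDNN's data_type_vnni_granularity; not part of the diff):

    // data_type_vnni_granularity: f32 -> 1, bf16/f16 -> 2, s8/u8 -> 4,
    // so for an int8 B matrix ld_step becomes 4 (one VNNI group per step)
    // instead of the previous hard-coded 1.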

const bool has_no_vnni_compute_instruction = false;
brg->rd_step = has_no_vnni_compute_instruction
99 changes: 68 additions & 31 deletions src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2021-2023 Intel Corporation
* Copyright 2024 FUJITSU LIMITED
* Copyright 2024-2025 FUJITSU LIMITED
* Copyright 2024 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -509,12 +509,20 @@ void jit_brgemm_kernel_t::cvt2ps(data_type_t type_in, const ZReg zmm_in,
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
break;
case data_type::bf16: assert(!"unsupported data type\n"); break;
case data_type::s8: assert(!"unsupported data type\n"); break;
case data_type::u8: assert(!"unsupported data type\n"); break;
case data_type::s8:
LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1);
sxtb(z_tmp_1().d, mask / T_m, z_tmp_1().d);
if (store) // Merging
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
break;
case data_type::u8:
LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1);
uxtb(z_tmp_1().s, mask / T_m, z_tmp_1().s);
if (store) // Merging
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
break;
default: assert(!"unsupported data type");
}
if (!one_of(type_in, data_type::f32, data_type::bf16))
assert(!"unsupported data type\n");
}

void jit_brgemm_kernel_t::advance_ldb_post_op_regs() {
@@ -961,7 +969,9 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(
}
for (int bd = 0; bd < bd_block; bd++) {
auto zmm = accm(ld_block2, bd, ld);
if (dq2ps_required) { scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s); }
if (dq2ps_required && !brg.with_scales) {
scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s);
}
if (brg.with_bias) { fadd(zmm.s, zmm.s, zmm_bias.s); }
}
}
@@ -989,7 +999,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(
auto vmm_zp_c = z_tmp_1();
if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
add_imm(X_DEFAULT_ADDR, reg_aux_zp_c_values, 0, X_TMP_0);
ldr(z_tmp_2(), ptr(X_DEFAULT_ADDR));
ld1rw(z_tmp_2().s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
scvtf(vmm_zp_c.s, k_mask / T_m, z_tmp_2().s);
}
for (int ld = 0; ld < ld_block2; ld++) {
@@ -1048,7 +1058,7 @@ void jit_brgemm_kernel_t::apply_compensation(
if (!brg.req_cal_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
auto vmm_zp_a_val = z_tmp_2();
add_imm(X_DEFAULT_ADDR, X_SP, reg_zp_a_val_offs_, X_TMP_0);
ldr(reg_zp_a_val, ptr(X_DEFAULT_ADDR));
add_imm(reg_zp_a_val, X_SP, reg_zp_a_val_offs_, X_TMP_0);
ldr(W_TMP_0, ptr(reg_zp_a_val));
dup(vmm_zp_a_val.s, W_TMP_0);

@@ -1096,13 +1106,15 @@ void jit_brgemm_kernel_t::apply_compensation(
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
if (IMPLICATION(is_tail, is_superset(brg.isa_impl, sve_512))) {
const auto mask = is_tail ? k_mask : P_ALL_ONE;
ld1w(vmm_comp.s, mask / T_z,
ptr(reg_aux_compensation, comp_offset));
add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset,
X_TMP_1);
ld1w(vmm_comp.s, mask / T_z, ptr(X_DEFAULT_ADDR));
} else {
not_(P_TMP.b, P_ALL_ONE, P_NOT_256.b);
cmplt(P_TMP.s, P_TMP / T_z, z_tail_mask().s, 0);
ld1w(vmm_comp.s, P_TMP / T_z,
ptr(reg_aux_compensation, comp_offset));
add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset,
X_TMP_1);
ld1w(vmm_comp.s, P_TMP / T_z, ptr(X_DEFAULT_ADDR));
}

for (int bd = 0; bd < bd_block; bd++) {
@@ -1154,7 +1166,11 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, bool is_bdb_tail,
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;

if (brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points)) {
assert(!"unsupported\n");
Label label_store_without_comp;
cmp_imm(reg_do_comp, 0, X_TMP_0);
b(EQ, label_store_without_comp);
apply_compensation(bd_block, ld_block2, is_ld_tail);
L_aligned(label_store_without_comp);
}

if (need_to_apply_alpha_beta)
@@ -1254,14 +1270,21 @@ void jit_brgemm_kernel_t::set_A_B_matrices() {
add(reg_aux_B, reg_aux_B, reg_b_offset);
}

void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v2, ZReg v3) {
void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v_b, ZReg v_a) {
Review comment (Contributor): v1, v2, and v3 are unclear, although consistent. v1, v_b, and v_a feels like the right direction, but is more confusing:

  • Why is b first?
  • What is v1 in this case? If a and b reference the A and B matrices, then it probably makes sense to give v1 a name with a similar theme, maybe v_acc.
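A minimal sketch of the suggested direction (hypothetical naming, not part of this PR):

    // Accumulator first, sources named for the matrices they hold:
    //   v_acc += dot(v_a, v_b)
    void jit_brgemm_kernel_t::dot_product(ZReg v_acc, ZReg v_a, ZReg v_b);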

if (brg.is_f32) {
fmla(v1.s, P_ALL_ONE / T_m, v2.s, v3.s);
fmla(v1.s, P_ALL_ONE / T_m, v_b.s, v_a.s);
} else if (brg.is_bf16)
assert(!"unsupported\n");
else if (brg.is_int8)
assert(!"unsupported\n");
else
else if (brg.is_int8 && isa_has_s8s8(brg.isa_impl)) {
if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8)
udot(v1.s, v_b.b, v_a.b);
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8)
sdot(v1.s, v_a.b, v_b.b);
else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8)
usdot(v1.s, v_a.b, v_b.b);
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8)
assert(!"unsupported\n");
Review comment (Contributor): Can we not just swap v_a and v_b in this case?
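A minimal sketch of that swap (hypothetical; it assumes Xbyak_aarch64's usdot takes the unsigned source vector first, matching the USDOT instruction):

    else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8)
        // B holds the unsigned operand in this case, so it goes first.
        usdot(v1.s, v_b.b, v_a.b);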

} else
assert(!"unsupported\n");
}

@@ -1294,7 +1317,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
if (brg.zp_type_a != brgemm_broadcast_t::none) {
eor(vmm_tmp.d, vmm_tmp.d, vmm_tmp.d);
dot_product(vmm_tmp, vmm_load, z_one_bytes());
mul(vmm_tmp.s, P_ALL_ONE / T_m, z_inp_shift().s);
mul(vmm_tmp.s, P_ALL_ONE / T_m, z_zp_a_shift().s);

for (int bd = bd_b; bd < bd_e; bd++) {
auto vmm = accm(ld_block2, bd, ld);
@@ -1321,7 +1344,8 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
for (int ld = 0; ld < ld_block2; ++ld) {
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
const auto mask = is_tail ? ld_tail_mask : P_ALL_ONE;
ld1w(load().s, mask / T_z, ptr(reg_aux_B, B_offset(ld, rd)));
add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0);
ld1w(load().s, mask / T_z, ptr(X_DEFAULT_ADDR));

if (brg.req_cal_comp_pads) {
compensation_padding(load(), bcst(), ld, bd_b, bd_e);
@@ -1348,8 +1372,11 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,

int rd_loop = 0, rd_tail_size = 0;
if (is_rd_tail) {
rd_tail_size = brg.rdb_tail % brg.rd_step;
if (brg.is_bf16 || brg.is_int8) {
assert(!"unsupported\n");
rd_loop = (rd_tail_size != 0)
? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step
: brg.rdb_tail;
} else
rd_loop = brg.rdb_tail;
} else
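To make the tail rounding concrete, a worked example with hypothetical values:

    // int8 => rd_step = 4 (bytes per VNNI step)
    // rdb_tail = 6 -> rd_tail_size = 6 % 4 = 2 (non-zero)
    // rd_loop = (6 / 4 + 1) * 4 = 8, i.e. rounded up to a whole VNNI step;
    // the two out-of-range bytes are masked off by the predicated tail
    // load (set_preg + ld1b) in the broadcast path below.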
@@ -1360,9 +1387,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
if (is_tail) {
eor(z1.d, z1.d, z1.d);
auto xmm_tmp = z_tmp_1();
add_imm(X_DEFAULT_ADDR, reg_aux_A, rd_tail_size * brg.typesize_A,
add_imm(X_DEFAULT_ADDR, reg_aux_A, offset * brg.typesize_A,
X_TMP_0);
set_preg(P_TMP.b, offset);
set_preg(P_TMP.b, rd_tail_size, X_TMP_0, X_TMP_1);
ld1b(xmm_tmp.b, P_TMP / T_z, ptr(X_DEFAULT_ADDR));
dup(z1.s, xmm_tmp.s[0]);
} else {
@@ -1377,7 +1404,8 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
} else if (dt == data_type::bf16) {
assert(!"unsupported\n");
} else if (one_of(dt, data_type::s8, data_type::u8)) {
assert(!"unsupported\n");
add_imm(X_DEFAULT_ADDR, reg_aux_A, offset, X_TMP_0);
ld1rw(z1.s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
} else if (dt == data_type::f16) {
assert(!"unsupported\n");
}
@@ -1389,7 +1417,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
const bool comp_vpad = vpad != 0
&& (brg.req_s8s8_compensation
|| brg.zp_type_a != brgemm_broadcast_t::none);
if (brg.req_cal_comp_pads || comp_vpad) assert(!"unsupported\n");
if (brg.req_cal_comp_pads || comp_vpad)
compute_int8_compensation(
rd_loop, bd_b, bd_e, bd_block, ld_block2, is_ld_tail, vpad);

bool maybe_load_bytes
= (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read)
@@ -1407,17 +1437,17 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
}
for (int ld = 0; ld < ld_block2; ld++) {
const auto addr = ptr(reg_aux_B, B_offset(ld, rd));
const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0);
if (brg.dt_b == data_type::f16) {
assert(!"unsupported\n");
} else if (brg.dt_b == data_type::bf16
&& brg.isa_impl == sve_256) {
assert(!"unsupported\n");
} else if (is_ld_tail) {
ld1w(load().s, ld_tail_mask / T_z, addr);
ld1w(load().s, ld_tail_mask / T_z, ptr(X_DEFAULT_ADDR));
} else {
ld1w(load().s, P_ALL_ONE / T_z, addr);
ld1w(load().s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
}
for (int bd = bd_b; bd < bd_e; bd++) {
auto vmm = accm(ld_block2, bd, ld);
@@ -1470,8 +1500,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
const auto bd_by_load_bytes
= (bd >= bd_e - rows_by_load_bytes
|| brg.brgattr.wary_A_k_tail_read);
broadcast(bcst(), A_offset(bd, rd),
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
int should_broadcast = static_cast<int>(
have_to_load_bytes && bd_by_load_bytes);
broadcast(bcst(), A_offset(bd, rd), should_broadcast,
brg.dt_a);

Review comment (Contributor): What does this mean? We have a variable called should_broadcast, but then we call broadcast even if it is not true. Are they two different kinds of broadcasting? If so, it would be great to make that clear in the variable/function names.
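One possible clarification (a hypothetical rename, not part of this PR; judging from broadcast()'s is_tail path above, the flag selects the predicated byte-wise tail load rather than whether any broadcast happens):

    const bool use_byte_tail_load = have_to_load_bytes && bd_by_load_bytes;
    broadcast(bcst(), A_offset(bd, rd), use_byte_tail_load, brg.dt_a);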
}
//The current implementation of prefetch is not giving any gain in performance but is rather introducing some latency. Therefore it is removed until a new useful implementation is devised.
for (int ld = 0; ld < ld_block2; ld++) {
@@ -1561,7 +1593,12 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail,

if (brg.req_s8s8_compensation) { assert(!"unsupported\n"); }
if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
assert(!"unsupported\n");
str(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_));
const auto reg32_scratch = WReg(reg_zp_a_input_shift.getIdx());
mov(z_one_bytes().b, 1);
ldr(reg32_scratch, ptr(X_SP, reg_zp_a_val_offs_));
dup(z_zp_a_shift().s, reg32_scratch);
ldr(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_));
}

if (brg.brgattr.max_bs > 1) { mov(reg_BS_loop, reg_BS); }