Skip to content

Commit 6b6d60e

Browse files
author
Kasture Deeksha
committed
cpu: aarch64: support for int8 datatype in brgemm kernel
1 parent b4e96e3 commit 6b6d60e

File tree

3 files changed

+79
-44
lines changed

3 files changed

+79
-44
lines changed

src/cpu/aarch64/brgemm/brgemm.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2020-2025 Intel Corporation
3-
* Copyright 2023-2024 FUJITSU LIMITED
3+
* Copyright 2023-2025 FUJITSU LIMITED
44
* Copyright 2024 Arm Ltd. and affiliates
55
*
66
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -229,12 +229,11 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
229229
brg->LDD = LDD;
230230
const auto dt_d = dst_md->data_type;
231231

232-
if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
233-
&& (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
234-
data_type::f32))
235-
&& (!one_of(dt_bias, data_type::undef, data_type::u8, data_type::s8,
236-
data_type::s32, data_type::f32, data_type::bf16)))
237-
return status::unimplemented;
232+
if (brg->is_int8) {
233+
if ((brg->dt_a == data_type::s8 && brg->dt_b == data_type::u8)
234+
|| (dt_bias != data_type::f32) || (dt_d != data_type::f32))
235+
return status::unimplemented;
236+
}
238237
if ((brg->dt_a == data_type::bf16 && brg->dt_b == data_type::bf16)
239238
&& (!one_of(dt_d, data_type::bf16, data_type::f32))
240239
&& (!one_of(dt_bias, data_type::undef, data_type::bf16,
@@ -248,7 +247,7 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
248247
brg->dt_d = dt_d;
249248
brg->typesize_D = types::data_type_size(brg->dt_d);
250249

251-
if (brg->is_int8 || (brg->dt_d == bf16)) return status::unimplemented;
250+
if (brg->dt_d == bf16) return status::unimplemented;
252251

253252
if (!brg->attr) return status::success;
254253

src/cpu/aarch64/brgemm/brgemm_utils.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2022-2023 Intel Corporation
3-
* Copyright 2023-2024 FUJITSU LIMITED
3+
* Copyright 2023-2025 FUJITSU LIMITED
44
* Copyright 2024 Arm Ltd. and affiliates
55
*
66
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -325,7 +325,8 @@ status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
325325
brg->has_int8_vnni = true;
326326

327327
set_brg_vmm(brg); // TODO: Investigate if it is really needed here.
328-
brg->req_s8s8_compensation = brg->is_int8 && brg->dt_a == data_type::s8;
328+
brg->req_s8s8_compensation = (brg->is_int8 && (brg->dt_a == data_type::s8)
329+
&& !isa_has_s8s8(brg->isa_impl));
329330

330331
brg->LDA = (brg->is_row_major()) ? static_cast<int>(LDA)
331332
: static_cast<int>(LDB);
@@ -344,9 +345,7 @@ status_t init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa,
344345
brg->bdb2 = 0;
345346
brg->bdb2_tail = 0;
346347

347-
const bool is_b_in_vnni_format = false;
348-
brg->ld_step
349-
= is_b_in_vnni_format ? data_type_vnni_granularity(brg->dt_b) : 1;
348+
brg->ld_step = data_type_vnni_granularity(brg->dt_b);
350349

351350
const bool has_no_vnni_compute_instruction = false;
352351
brg->rd_step = has_no_vnni_compute_instruction

src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*******************************************************************************
22
* Copyright 2021-2023 Intel Corporation
3-
* Copyright 2024 FUJITSU LIMITED
3+
* Copyright 2024-2025 FUJITSU LIMITED
44
* Copyright 2024 Arm Ltd. and affiliates
55
*
66
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -509,12 +509,20 @@ void jit_brgemm_kernel_t::cvt2ps(data_type_t type_in, const ZReg zmm_in,
509509
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
510510
break;
511511
case data_type::bf16: assert(!"unsupported data type\n"); break;
512-
case data_type::s8: assert(!"unsupported data type\n"); break;
513-
case data_type::u8: assert(!"unsupported data type\n"); break;
512+
case data_type::s8:
513+
LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1);
514+
sxtb(z_tmp_1().d, mask / T_m, z_tmp_1().d);
515+
if (store) // Merging
516+
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
517+
break;
518+
case data_type::u8:
519+
LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1);
520+
uxtb(z_tmp_1().s, mask / T_m, z_tmp_1().s);
521+
if (store) // Merging
522+
mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
523+
break;
514524
default: assert(!"unsupported data type");
515525
}
516-
if (!one_of(type_in, data_type::f32, data_type::bf16))
517-
assert(!"unsupported data type\n");
518526
}
519527

520528
void jit_brgemm_kernel_t::advance_ldb_post_op_regs() {
@@ -961,7 +969,9 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(
961969
}
962970
for (int bd = 0; bd < bd_block; bd++) {
963971
auto zmm = accm(ld_block2, bd, ld);
964-
if (dq2ps_required) { scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s); }
972+
if (dq2ps_required && !brg.with_scales) {
973+
scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s);
974+
}
965975
if (brg.with_bias) { fadd(zmm.s, zmm.s, zmm_bias.s); }
966976
}
967977
}
@@ -989,7 +999,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(
989999
auto vmm_zp_c = z_tmp_1();
9901000
if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
9911001
add_imm(X_DEFAULT_ADDR, reg_aux_zp_c_values, 0, X_TMP_0);
992-
ldr(z_tmp_2(), ptr(X_DEFAULT_ADDR));
1002+
ld1rw(z_tmp_2().s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
9931003
scvtf(vmm_zp_c.s, k_mask / T_m, z_tmp_2().s);
9941004
}
9951005
for (int ld = 0; ld < ld_block2; ld++) {
@@ -1048,7 +1058,7 @@ void jit_brgemm_kernel_t::apply_compensation(
10481058
if (!brg.req_cal_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
10491059
auto vmm_zp_a_val = z_tmp_2();
10501060
add_imm(X_DEFAULT_ADDR, X_SP, reg_zp_a_val_offs_, X_TMP_0);
1051-
ldr(reg_zp_a_val, ptr(X_DEFAULT_ADDR));
1061+
add_imm(reg_zp_a_val, X_SP, reg_zp_a_val_offs_, X_TMP_0);
10521062
ldr(W_TMP_0, ptr(reg_zp_a_val));
10531063
dup(vmm_zp_a_val.s, W_TMP_0);
10541064

@@ -1096,13 +1106,15 @@ void jit_brgemm_kernel_t::apply_compensation(
10961106
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
10971107
if (IMPLICATION(is_tail, is_superset(brg.isa_impl, sve_512))) {
10981108
const auto mask = is_tail ? k_mask : P_ALL_ONE;
1099-
ld1w(vmm_comp.s, mask / T_z,
1100-
ptr(reg_aux_compensation, comp_offset));
1109+
add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset,
1110+
X_TMP_1);
1111+
ld1w(vmm_comp.s, mask / T_z, ptr(X_DEFAULT_ADDR));
11011112
} else {
11021113
not_(P_TMP.b, P_ALL_ONE, P_NOT_256.b);
11031114
cmplt(P_TMP.s, P_TMP / T_z, z_tail_mask().s, 0);
1104-
ld1w(vmm_comp.s, P_TMP / T_z,
1105-
ptr(reg_aux_compensation, comp_offset));
1115+
add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset,
1116+
X_TMP_1);
1117+
ld1w(vmm_comp.s, P_TMP / T_z, ptr(X_DEFAULT_ADDR));
11061118
}
11071119

11081120
for (int bd = 0; bd < bd_block; bd++) {
@@ -1154,7 +1166,11 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, bool is_bdb_tail,
11541166
int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
11551167

11561168
if (brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points)) {
1157-
assert(!"unsupported\n");
1169+
Label label_store_without_comp;
1170+
cmp_imm(reg_do_comp, 0, X_TMP_0);
1171+
b(EQ, label_store_without_comp);
1172+
apply_compensation(bd_block, ld_block2, is_ld_tail);
1173+
L_aligned(label_store_without_comp);
11581174
}
11591175

11601176
if (need_to_apply_alpha_beta)
@@ -1254,14 +1270,21 @@ void jit_brgemm_kernel_t::set_A_B_matrices() {
12541270
add(reg_aux_B, reg_aux_B, reg_b_offset);
12551271
}
12561272

1257-
void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v2, ZReg v3) {
1273+
void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v_b, ZReg v_a) {
12581274
if (brg.is_f32) {
1259-
fmla(v1.s, P_ALL_ONE / T_m, v2.s, v3.s);
1275+
fmla(v1.s, P_ALL_ONE / T_m, v_b.s, v_a.s);
12601276
} else if (brg.is_bf16)
12611277
assert(!"unsupported\n");
1262-
else if (brg.is_int8)
1263-
assert(!"unsupported\n");
1264-
else
1278+
else if (brg.is_int8 && isa_has_s8s8(brg.isa_impl)) {
1279+
if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8)
1280+
udot(v1.s, v_b.b, v_a.b);
1281+
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8)
1282+
sdot(v1.s, v_a.b, v_b.b);
1283+
else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8)
1284+
usdot(v1.s, v_a.b, v_b.b);
1285+
else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8)
1286+
assert(!"unsupported\n");
1287+
} else
12651288
assert(!"unsupported\n");
12661289
}
12671290

@@ -1294,7 +1317,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
12941317
if (brg.zp_type_a != brgemm_broadcast_t::none) {
12951318
eor(vmm_tmp.d, vmm_tmp.d, vmm_tmp.d);
12961319
dot_product(vmm_tmp, vmm_load, z_one_bytes());
1297-
mul(vmm_tmp.s, P_ALL_ONE / T_m, z_inp_shift().s);
1320+
mul(vmm_tmp.s, P_ALL_ONE / T_m, z_zp_a_shift().s);
12981321

12991322
for (int bd = bd_b; bd < bd_e; bd++) {
13001323
auto vmm = accm(ld_block2, bd, ld);
@@ -1321,7 +1344,8 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
13211344
for (int ld = 0; ld < ld_block2; ++ld) {
13221345
const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
13231346
const auto mask = is_tail ? ld_tail_mask : P_ALL_ONE;
1324-
ld1w(load().s, mask / T_z, ptr(reg_aux_B, B_offset(ld, rd)));
1347+
add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0);
1348+
ld1w(load().s, mask / T_z, ptr(X_DEFAULT_ADDR));
13251349

13261350
if (brg.req_cal_comp_pads) {
13271351
compensation_padding(load(), bcst(), ld, bd_b, bd_e);
@@ -1348,8 +1372,11 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
13481372

13491373
int rd_loop = 0, rd_tail_size = 0;
13501374
if (is_rd_tail) {
1375+
rd_tail_size = brg.rdb_tail % brg.rd_step;
13511376
if (brg.is_bf16 || brg.is_int8) {
1352-
assert(!"unsupported\n");
1377+
rd_loop = (rd_tail_size != 0)
1378+
? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step
1379+
: brg.rdb_tail;
13531380
} else
13541381
rd_loop = brg.rdb_tail;
13551382
} else
@@ -1360,9 +1387,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
13601387
if (is_tail) {
13611388
eor(z1.d, z1.d, z1.d);
13621389
auto xmm_tmp = z_tmp_1();
1363-
add_imm(X_DEFAULT_ADDR, reg_aux_A, rd_tail_size * brg.typesize_A,
1390+
add_imm(X_DEFAULT_ADDR, reg_aux_A, offset * brg.typesize_A,
13641391
X_TMP_0);
1365-
set_preg(P_TMP.b, offset);
1392+
set_preg(P_TMP.b, rd_tail_size, X_TMP_0, X_TMP_1);
13661393
ld1b(xmm_tmp.b, P_TMP / T_z, ptr(X_DEFAULT_ADDR));
13671394
dup(z1.s, xmm_tmp.s[0]);
13681395
} else {
@@ -1377,7 +1404,8 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
13771404
} else if (dt == data_type::bf16) {
13781405
assert(!"unsupported\n");
13791406
} else if (one_of(dt, data_type::s8, data_type::u8)) {
1380-
assert(!"unsupported\n");
1407+
add_imm(X_DEFAULT_ADDR, reg_aux_A, offset, X_TMP_0);
1408+
ld1rw(z1.s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
13811409
} else if (dt == data_type::f16) {
13821410
assert(!"unsupported\n");
13831411
}
@@ -1389,7 +1417,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
13891417
const bool comp_vpad = vpad != 0
13901418
&& (brg.req_s8s8_compensation
13911419
|| brg.zp_type_a != brgemm_broadcast_t::none);
1392-
if (brg.req_cal_comp_pads || comp_vpad) assert(!"unsupported\n");
1420+
if (brg.req_cal_comp_pads || comp_vpad)
1421+
compute_int8_compensation(
1422+
rd_loop, bd_b, bd_e, bd_block, ld_block2, is_ld_tail, vpad);
13931423

13941424
bool maybe_load_bytes
13951425
= (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read)
@@ -1407,17 +1437,17 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
14071437
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
14081438
}
14091439
for (int ld = 0; ld < ld_block2; ld++) {
1410-
const auto addr = ptr(reg_aux_B, B_offset(ld, rd));
14111440
const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
1441+
add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0);
14121442
if (brg.dt_b == data_type::f16) {
14131443
assert(!"unsupported\n");
14141444
} else if (brg.dt_b == data_type::bf16
14151445
&& brg.isa_impl == sve_256) {
14161446
assert(!"unsupported\n");
14171447
} else if (is_ld_tail) {
1418-
ld1w(load().s, ld_tail_mask / T_z, addr);
1448+
ld1w(load().s, ld_tail_mask / T_z, ptr(X_DEFAULT_ADDR));
14191449
} else {
1420-
ld1w(load().s, P_ALL_ONE / T_z, addr);
1450+
ld1w(load().s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
14211451
}
14221452
for (int bd = bd_b; bd < bd_e; bd++) {
14231453
auto vmm = accm(ld_block2, bd, ld);
@@ -1470,8 +1500,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
14701500
const auto bd_by_load_bytes
14711501
= (bd >= bd_e - rows_by_load_bytes
14721502
|| brg.brgattr.wary_A_k_tail_read);
1473-
broadcast(bcst(), A_offset(bd, rd),
1474-
have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
1503+
int should_broadcast = static_cast<int>(
1504+
have_to_load_bytes && bd_by_load_bytes);
1505+
broadcast(bcst(), A_offset(bd, rd), should_broadcast,
1506+
brg.dt_a);
14751507
}
14761508
// The current implementation of prefetch is not giving any gain in performance but is rather introducing some latency. Therefore it is removed until a new useful implementation is devised.
14771509
for (int ld = 0; ld < ld_block2; ld++) {
@@ -1561,7 +1593,12 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail,
15611593

15621594
if (brg.req_s8s8_compensation) { assert(!"unsupported\n"); }
15631595
if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
1564-
assert(!"unsupported\n");
1596+
str(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_));
1597+
const auto reg32_scratch = WReg(reg_zp_a_input_shift.getIdx());
1598+
mov(z_one_bytes().b, 1);
1599+
ldr(reg32_scratch, ptr(X_SP, reg_zp_a_val_offs_));
1600+
dup(z_zp_a_shift().s, reg32_scratch);
1601+
ldr(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_));
15651602
}
15661603

15671604
if (brg.brgattr.max_bs > 1) { mov(reg_BS_loop, reg_BS); }

0 commit comments

Comments
 (0)