@@ -1,6 +1,6 @@
 /*******************************************************************************
 * Copyright 2021-2023 Intel Corporation
-* Copyright 2024 FUJITSU LIMITED
+* Copyright 2024-2025 FUJITSU LIMITED
 * Copyright 2024 Arm Ltd. and affiliates
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -509,12 +509,20 @@ void jit_brgemm_kernel_t::cvt2ps(data_type_t type_in, const ZReg zmm_in,
                 mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
             break;
         case data_type::bf16: assert(!"unsupported data type\n"); break;
-        case data_type::s8: assert(!"unsupported data type\n"); break;
-        case data_type::u8: assert(!"unsupported data type\n"); break;
+        case data_type::s8:
+            LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1);
+            sxtb(z_tmp_1().d, mask / T_m, z_tmp_1().d);
+            if (store) // Merging
+                mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
+            break;
+        case data_type::u8:
+            LD_MUL_VL(ld1b, z_tmp_1().b, mask, addr, offset - base_offset, 1);
+            uxtb(z_tmp_1().s, mask / T_m, z_tmp_1().s);
+            if (store) // Merging
+                mov(zmm_in.s, ktail_mask / T_m, z_tmp_1().s);
+            break;
         default: assert(!"unsupported data type");
     }
-    if (!one_of(type_in, data_type::f32, data_type::bf16))
-        assert(!"unsupported data type\n");
 }
 
 void jit_brgemm_kernel_t::advance_ldb_post_op_regs() {
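Note: the new s8/u8 cases in cvt2ps() load one byte per element with ld1b and widen it in-register (sxtb for signed, uxtb for unsigned) so the rest of the routine can treat each lane as a full-width integer. A minimal scalar sketch of that widening step (reference semantics only, not the JIT code):

    #include <cstdint>
    // sxtb: sign-extend the low byte of a lane; uxtb: zero-extend it.
    static inline int32_t widen_s8(int8_t v) { return static_cast<int32_t>(v); }
    static inline int32_t widen_u8(uint8_t v) { return static_cast<int32_t>(v); }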
@@ -961,7 +969,9 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(
             }
             for (int bd = 0; bd < bd_block; bd++) {
                 auto zmm = accm(ld_block2, bd, ld);
-                if (dq2ps_required) { scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s); }
+                if (dq2ps_required && !brg.with_scales) {
+                    scvtf(zmm.s, P_ALL_ONE / T_m, zmm.s);
+                }
                 if (brg.with_bias) { fadd(zmm.s, zmm.s, zmm_bias.s); }
             }
         }
@@ -989,7 +999,7 @@ void jit_brgemm_kernel_t::store_accumulators_apply_post_ops(
         auto vmm_zp_c = z_tmp_1();
         if (brg.zp_type_c == brgemm_broadcast_t::per_tensor) {
             add_imm(X_DEFAULT_ADDR, reg_aux_zp_c_values, 0, X_TMP_0);
-            ldr(z_tmp_2(), ptr(X_DEFAULT_ADDR));
+            ld1rw(z_tmp_2().s, P_ALL_ONE, ptr(X_DEFAULT_ADDR));
             scvtf(vmm_zp_c.s, k_mask / T_m, z_tmp_2().s);
         }
         for (int ld = 0; ld < ld_block2; ld++) {
@@ -1048,7 +1058,7 @@ void jit_brgemm_kernel_t::apply_compensation(
     if (!brg.req_cal_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
         auto vmm_zp_a_val = z_tmp_2();
         add_imm(X_DEFAULT_ADDR, X_SP, reg_zp_a_val_offs_, X_TMP_0);
-        ldr(reg_zp_a_val, ptr(X_DEFAULT_ADDR));
+        add_imm(reg_zp_a_val, X_SP, reg_zp_a_val_offs_, X_TMP_0);
         ldr(W_TMP_0, ptr(reg_zp_a_val));
         dup(vmm_zp_a_val.s, W_TMP_0);
 
@@ -1096,13 +1106,15 @@ void jit_brgemm_kernel_t::apply_compensation(
                 const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
                 if (IMPLICATION(is_tail, is_superset(brg.isa_impl, sve_512))) {
                     const auto mask = is_tail ? k_mask : P_ALL_ONE;
-                    ld1w(vmm_comp.s, mask / T_z,
-                            ptr(reg_aux_compensation, comp_offset));
+                    add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset,
+                            X_TMP_1);
+                    ld1w(vmm_comp.s, mask / T_z, ptr(X_DEFAULT_ADDR));
                 } else {
                     not_(P_TMP.b, P_ALL_ONE, P_NOT_256.b);
                     cmplt(P_TMP.s, P_TMP / T_z, z_tail_mask().s, 0);
-                    ld1w(vmm_comp.s, P_TMP / T_z,
-                            ptr(reg_aux_compensation, comp_offset));
+                    add_imm(X_DEFAULT_ADDR, reg_aux_compensation, comp_offset,
+                            X_TMP_1);
+                    ld1w(vmm_comp.s, P_TMP / T_z, ptr(X_DEFAULT_ADDR));
                 }
 
                 for (int bd = 0; bd < bd_block; bd++) {
@@ -1154,7 +1166,11 @@ void jit_brgemm_kernel_t::store_accumulators(int bd_block2, bool is_bdb_tail,
     int bd_block = (is_bdb_tail) ? brg.bdb_tail : brg.bd_block;
 
     if (brg.is_int8 && (brg.req_s8s8_compensation || has_zero_points)) {
-        assert(!"unsupported\n");
+        Label label_store_without_comp;
+        cmp_imm(reg_do_comp, 0, X_TMP_0);
+        b(EQ, label_store_without_comp);
+        apply_compensation(bd_block, ld_block2, is_ld_tail);
+        L_aligned(label_store_without_comp);
     }
 
     if (need_to_apply_alpha_beta)
@@ -1254,14 +1270,21 @@ void jit_brgemm_kernel_t::set_A_B_matrices() {
     add(reg_aux_B, reg_aux_B, reg_b_offset);
 }
 
-void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v2, ZReg v3) {
+void jit_brgemm_kernel_t::dot_product(ZReg v1, ZReg v_b, ZReg v_a) {
     if (brg.is_f32) {
-        fmla(v1.s, P_ALL_ONE / T_m, v2.s, v3.s);
+        fmla(v1.s, P_ALL_ONE / T_m, v_b.s, v_a.s);
     } else if (brg.is_bf16)
        assert(!"unsupported\n");
-    else if (brg.is_int8)
-        assert(!"unsupported\n");
-    else
+    else if (brg.is_int8 && isa_has_s8s8(brg.isa_impl)) {
+        if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::u8)
+            udot(v1.s, v_b.b, v_a.b);
+        else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::s8)
+            sdot(v1.s, v_a.b, v_b.b);
+        else if (brg.dt_a == data_type::u8 && brg.dt_b == data_type::s8)
+            usdot(v1.s, v_a.b, v_b.b);
+        else if (brg.dt_a == data_type::s8 && brg.dt_b == data_type::u8)
+            assert(!"unsupported\n");
+    } else
         assert(!"unsupported\n");
 }
 
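Note: for int8 the kernel now maps dot_product() onto the SVE byte dot-product instructions: udot for u8*u8, sdot for s8*s8, and usdot (first source unsigned, second signed) for u8*s8, while s8*u8 remains unsupported. Each instruction makes every 32-bit accumulator lane absorb the dot product of the corresponding group of four 8-bit elements. A scalar model of one accumulator lane (a sketch of the instruction semantics, not oneDNN code):

    #include <cstdint>
    // One 32-bit lane of sdot: acc += sum of four signed 8-bit products.
    static inline int32_t sdot_lane(int32_t acc, const int8_t a[4], const int8_t b[4]) {
        for (int i = 0; i < 4; ++i) acc += int32_t(a[i]) * int32_t(b[i]);
        return acc;
    }
    // udot is the same with uint8_t operands; usdot mixes uint8_t and int8_t.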
@@ -1294,7 +1317,7 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
         if (brg.zp_type_a != brgemm_broadcast_t::none) {
             eor(vmm_tmp.d, vmm_tmp.d, vmm_tmp.d);
             dot_product(vmm_tmp, vmm_load, z_one_bytes());
-            mul(vmm_tmp.s, P_ALL_ONE / T_m, z_inp_shift().s);
+            mul(vmm_tmp.s, P_ALL_ONE / T_m, z_zp_a_shift().s);
 
             for (int bd = bd_b; bd < bd_e; bd++) {
                 auto vmm = accm(ld_block2, bd, ld);
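Note: the zp_type_a branch builds the A-zero-point correction from a dot product of B with a vector of all-ones bytes (z_one_bytes()), then scales it by the broadcast zero point (z_zp_a_shift(), renamed here from z_inp_shift()). This follows the usual int8 GEMM identity; a scalar sketch of the assumed math (not the kernel itself):

    #include <cstdint>
    // sum_k (a[k] - zp_a) * b[k] == sum_k a[k]*b[k] - zp_a * sum_k b[k]
    static inline int32_t compensated_dot(
            const uint8_t *a, const int8_t *b, int K, int32_t zp_a) {
        int32_t acc = 0, b_sum = 0;
        for (int k = 0; k < K; ++k) {
            acc += int32_t(a[k]) * int32_t(b[k]);
            b_sum += int32_t(b[k]); // what the ones-vector dot product accumulates
        }
        return acc - zp_a * b_sum;
    }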
@@ -1321,7 +1344,8 @@ void jit_brgemm_kernel_t::compute_int8_compensation(int rd_loop, int bd_b,
     for (int ld = 0; ld < ld_block2; ++ld) {
         const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
         const auto mask = is_tail ? ld_tail_mask : P_ALL_ONE;
-        ld1w(load().s, mask / T_z, ptr(reg_aux_B, B_offset(ld, rd)));
+        add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0);
+        ld1w(load().s, mask / T_z, ptr(X_DEFAULT_ADDR));
 
         if (brg.req_cal_comp_pads) {
             compensation_padding(load(), bcst(), ld, bd_b, bd_e);
@@ -1348,8 +1372,11 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
 
     int rd_loop = 0, rd_tail_size = 0;
     if (is_rd_tail) {
+        rd_tail_size = brg.rdb_tail % brg.rd_step;
         if (brg.is_bf16 || brg.is_int8) {
-            assert(!"unsupported\n");
+            rd_loop = (rd_tail_size != 0)
+                    ? ((brg.rdb_tail / brg.rd_step) + 1) * brg.rd_step
+                    : brg.rdb_tail;
         } else
             rd_loop = brg.rdb_tail;
     } else
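Note: for bf16/int8 the reduction dimension is consumed rd_step elements per step, so a tail that is not a multiple of rd_step is rounded up to the next multiple, while rd_tail_size keeps the true remainder for the masked tail load. The arithmetic, as a small sketch of the assumed intent:

    // e.g. rdb_tail = 7, rd_step = 4  ->  rd_loop = 8, rd_tail_size = 3
    static inline int rd_loop_for_tail(int rdb_tail, int rd_step) {
        const int rem = rdb_tail % rd_step;
        return rem ? (rdb_tail / rd_step + 1) * rd_step : rdb_tail;
    }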
@@ -1360,9 +1387,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
         if (is_tail) {
             eor(z1.d, z1.d, z1.d);
             auto xmm_tmp = z_tmp_1();
-            add_imm(X_DEFAULT_ADDR, reg_aux_A, rd_tail_size * brg.typesize_A,
+            add_imm(X_DEFAULT_ADDR, reg_aux_A, offset * brg.typesize_A,
                     X_TMP_0);
-            set_preg(P_TMP.b, offset);
+            set_preg(P_TMP.b, rd_tail_size, X_TMP_0, X_TMP_1);
             ld1b(xmm_tmp.b, P_TMP / T_z, ptr(X_DEFAULT_ADDR));
             dup(z1.s, xmm_tmp.s[0]);
         } else {
@@ -1377,7 +1404,8 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
         } else if (dt == data_type::bf16) {
             assert(!"unsupported\n");
         } else if (one_of(dt, data_type::s8, data_type::u8)) {
-            assert(!"unsupported\n");
+            add_imm(X_DEFAULT_ADDR, reg_aux_A, offset, X_TMP_0);
+            ld1rw(z1.s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
        } else if (dt == data_type::f16) {
             assert(!"unsupported\n");
         }
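Note: for s8/u8 the broadcast path uses ld1rw, which loads a single 32-bit word and replicates it to every active lane, so the four consecutive K-values of A packed in that word appear to line up with the groups of four bytes that sdot/udot/usdot consume per lane. A scalar sketch of that packing (hypothetical helper, not oneDNN code):

    #include <cstdint>
    #include <cstring>
    // The 32-bit word that ld1rw would replicate across the vector.
    static inline uint32_t pack_four_a_bytes(const uint8_t *a_ptr) {
        uint32_t word;
        std::memcpy(&word, a_ptr, sizeof(word));
        return word;
    }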
@@ -1389,7 +1417,9 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
         const bool comp_vpad = vpad != 0
                 && (brg.req_s8s8_compensation
                         || brg.zp_type_a != brgemm_broadcast_t::none);
-        if (brg.req_cal_comp_pads || comp_vpad) assert(!"unsupported\n");
+        if (brg.req_cal_comp_pads || comp_vpad)
+            compute_int8_compensation(
+                    rd_loop, bd_b, bd_e, bd_block, ld_block2, is_ld_tail, vpad);
 
         bool maybe_load_bytes
                 = (rows_for_rd_tail > 0 || brg.brgattr.wary_A_k_tail_read)
@@ -1407,17 +1437,17 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
                         have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
             }
             for (int ld = 0; ld < ld_block2; ld++) {
-                const auto addr = ptr(reg_aux_B, B_offset(ld, rd));
                 const auto mask = is_ld_tail ? ld_tail_mask : P_ALL_ONE;
+                add_imm(X_DEFAULT_ADDR, reg_aux_B, B_offset(ld, rd), X_TMP_0);
                 if (brg.dt_b == data_type::f16) {
                     assert(!"unsupported\n");
                 } else if (brg.dt_b == data_type::bf16
                         && brg.isa_impl == sve_256) {
                     assert(!"unsupported\n");
                 } else if (is_ld_tail) {
-                    ld1w(load().s, ld_tail_mask / T_z, addr);
+                    ld1w(load().s, ld_tail_mask / T_z, ptr(X_DEFAULT_ADDR));
                 } else {
-                    ld1w(load().s, P_ALL_ONE / T_z, addr);
+                    ld1w(load().s, P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
                 }
                 for (int bd = bd_b; bd < bd_e; bd++) {
                     auto vmm = accm(ld_block2, bd, ld);
@@ -1470,8 +1500,10 @@ void jit_brgemm_kernel_t::gemm_microkernel_sve512(int bd_block2,
                 const auto bd_by_load_bytes
                         = (bd >= bd_e - rows_by_load_bytes
                                 || brg.brgattr.wary_A_k_tail_read);
-                broadcast(bcst(), A_offset(bd, rd),
-                        have_to_load_bytes && bd_by_load_bytes, brg.dt_a);
+                int should_broadcast = static_cast<int>(
+                        have_to_load_bytes && bd_by_load_bytes);
+                broadcast(bcst(), A_offset(bd, rd), should_broadcast,
+                        brg.dt_a);
             }
             // The current implementation of prefetch is not giving any gain in performance but is rather introducing some latency. Therefore it is removed until a new useful implementation is devised.
             for (int ld = 0; ld < ld_block2; ld++) {
@@ -1561,7 +1593,12 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail,
 
     if (brg.req_s8s8_compensation) { assert(!"unsupported\n"); }
     if (need_comp_pads && brg.zp_type_a != brgemm_broadcast_t::none) {
-        assert(!"unsupported\n");
+        str(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_));
+        const auto reg32_scratch = WReg(reg_zp_a_input_shift.getIdx());
+        mov(z_one_bytes().b, 1);
+        ldr(reg32_scratch, ptr(X_SP, reg_zp_a_val_offs_));
+        dup(z_zp_a_shift().s, reg32_scratch);
+        ldr(reg_bdb_loop, ptr(X_SP, reg_bdb_loop_offs_));
     }
 
     if (brg.brgattr.max_bs > 1) { mov(reg_BS_loop, reg_BS); }