Skip to content

xelpg: jit: gemm: additional f16 accumulation strategies #3417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/gpu/intel/jit/gemm/gen_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ struct gen_gemm_t : public gpu_gemm_t {
? attr_zps.get_data_type(DNNL_ARG_B)
: data_type::s32;
if (swap_ab_) std::swap(ao_type, bo_type);
bool int_acc = utils::one_of(eff_a_type(), s8, u8);
bool int_acc = utils::one_of(eff_a_type(), s8, u8) && !wei_decomp_;
int_acc &= !wei_scales_2d_;
auto co_type = with_bias() ? d->bias_type()
: with_sum_ab() ? d->sum_ab_type
Expand Down Expand Up @@ -297,10 +297,8 @@ struct gen_gemm_t : public gpu_gemm_t {
if (attr()->acc_mode_ == accumulation_mode::relaxed)
set_mode(mode, kernel_desc_t::mode_relaxed_acc);

if (wei_decomp_) {
acc_type = data_type::f32;
if (wei_decomp_)
set_mode(mode, kernel_desc_t::mode_w_decomp);
}

// GEMM kernels down convert the following parameters to
// int/uint32_t
Expand Down
10 changes: 10 additions & 0 deletions src/gpu/intel/jit/gemm/selector/db/kernel.db
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ auto _CATALOG_ = kcatalog::toArray({
{{'E', "gemm", {"F", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "Ixyz"}, "sS32x2 sB16 sB wg 16x2 cb4 ks32 xaf dw vav bo sr bk0 sn dm grf256 sys pab", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {262144, 1048576, 16777216}, {262144, 1048576, 16777216}, {16, 64, 32}, {16, 2, 1}, 1, (WGType) 1, 257, 32768, 0, {4, 4, 4}, {false, false, false}}, {'E', 17, {982213, 473301, 0, 0, 0, 0, 1.74644, 5.1767, 6.10829, 17.1708, 0.0167439, 0.0136956, 0.00599404, 0.999577, 1.37511, 1.22059, 7.53689e-13}}},
{{'E', "gemm", {"F", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "Isxyz"}, "sS64 sB16x2 sB wg 16x2 cb4 ks64 xaf fx dw vav bo sr bk0 sn dm grf256 sys pab", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {262144, 786432, 16777216}, {262144, 786432, 16777216}, {16, 48, 64}, {16, 2, 1}, 1, (WGType) 1, 257, 49152, 0, {4, 4, 4}, {false, false, false}}, {'E', 17, {981842, 458926, 0, 0, 0, 0, 1.5015, 5.00498, 6.31005, 16.9024, 0.0169639, 0.0400974, 0, 0.719651, 1.35848, 1.1845, 9.89658e-13}}},
{{'E', "gemm", {"F", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, 33, -1}, {-1, -1, -1}, {4, 4, 1}, "xyIs"}, "sB16 sB32 aB wg 4x8 cab3x2 ks32 xaf st dw vav bo sr bk0 dm grf256 sys pab", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 393216, 16777216}, {524288, 393216, 16777216}, {32, 24, 32}, {4, 8, 1}, 1, (WGType) 1, 257, 61440, 0, {4, 4, 4}, {false, false, true}}, {'E', 17, {1.01302e+06, 570829, 0, 0, 0, 0, 3.67307, 6.66635, 6.86396, 18.2302, 0.0202076, 0.0155595, 0.00597746, 1, 1.56109, 1.12816, 4.46535e-12}}},
{{'E', "gemm", {"F", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB4x2 aS8x2 aB wg 4x8 kc4 ca4 ks8 nse bo sr sm grf256 dm", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 1, 16384, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {7.16657e+06, 1.57818e+06, 0, 0, 0, 0, 3.86938, 6.19363, 2.90486, 8.10501, 0.0644096, 0.0624624, 0.0048353, 0.950806, 1.02021, 1.10508, -1.30024e-11}}},
{{'E', "gemm", {"F", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB4x2 aS8 aB wg 2x16 kc4 ca4x2 ks8 nse bo sr dm", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 393216, 16777216}, {524288, 393216, 16777216}, {32, 24, 8}, {2, 16, 1}, 1, (WGType) 1, 1, 4096, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {1.4308e+06, 1.0662e+06, 0, 0, 0, 0, 5.11994, 5.6138, 2.1297, 7.62273, 0.0655758, 0.051199, 0.0157761, 0.516429, 1.01847, 1.00258, -2.45704e-14}}},
{{'E', "gemm", {"F", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB2x2 aS8x2 aB wg 2x8 kc2 ca3x2 ks8 nse bo sr grf256 dm", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {2, 8, 1}, 1, (WGType) 1, 1, 3072, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {1.68419e+06, 974445, 0, 0, 0, 0, 4.14746, 5.3584, 1.49125, 6.1661, 0.0647189, 0.0519009, 0.0180995, 0.660031, 1.00378, 1.34304, -4.89949e-11}}},
{{'E', "gemm", {"F", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "s"}, "aB8x2 aS8x2 aB wg 4x8 kc8 ca4 ks8 nse bo sr grf256 dm", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 393216, 16777216}, {524288, 393216, 16777216}, {32, 24, 8}, {4, 8, 1}, 1, (WGType) 1, 1, 8192, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {580644, 1.95257e+06, 0, 0, 0, 0, 2.80946, 5.24923, 1.62269, 5.91446, 0.0656623, 0.0613487, 0.0144173, 0.938521, 1.01882, 1.00041, 6.17388e-15}}},
{{'E', "gemm", {"F", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aS8x2 aB8 aB wg 4x4x2 ikr kc8 cb4x2 ks8 fn nmk nse bo sr sn grf256 dm", {8, (LoopType) 1, 256, {(LoopType) 129, (LoopType) 255, (LoopType) 2}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 8}, {4, 4, 2}, 1, (WGType) 1, 4101, 8192, 16384, {4, 4, 2}, {true, true, true}}, {'E', 17, {2.3029e+06, 1.34131e+06, -582626, 920072, 0, 0, 3.10241, 5.22949, -1.70295, 3.11302, 0.0694663, 0.0658091, 0.00905694, 0.934931, 1.01551, 1.27536, -5.68199e-11}}},
{{'E', "gemm", {"F", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, 2, -1}, {-1, 8, -1}, {-1, 2, -1}, {-1, 8, -1}, {4, 4, 1}, "Iqxy"}, "sB64 sS16 aS wg 2x1x8 ikr af acb sr bk0 bm0 sys pab grf256 rc0", {8, (LoopType) 0, 256, {(LoopType) 0, (LoopType) 1, (LoopType) 2}, {16777216, 8192, 16777216}, {8192, 8192, 16777216}, {16, 8, 64}, {2, 1, 8}, 1, (WGType) 0, 4357, 0, 1024, {4, 4, 4}, {false, false, true}}, {'E', 17, {3.5449e+06, 60571.4, -243099, 15595.1, 0, 0, 1.78243, 2.8889, 2.76679, 6.10171, 0.051381, 0.0216118, 0.0510683, 1, 1.21576, 1.21633, -9.23968e-14}}},
{{'E', "gemm", {"F", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, 2, -1}, {-1, 8, -1}, {-1, 2, -1}, {-1, 8, -1}, {4, 4, 1}, "IQxy"}, "sS64 sB32 aB wg 2x1x8 ikr ki64 sys af k64 grf256 acb di sr nch fm pab rc0", {8, (LoopType) 0, 256, {(LoopType) 0, (LoopType) 1, (LoopType) 2}, {8192, 8192, 16777216}, {8192, 8192, 16777216}, {16, 8, 64}, {2, 1, 8}, 1, (WGType) 0, 4357, 0, 1024, {4, 4, 4}, {false, false, true}}, {'W', 1, {128}}},
{{'E', "gemm", {"F", "O", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {16, 16, 1}, "Ixyz"}, "sB64 sB32x2 sB wg 4x8 ca4x2 ks64 af dw nse hi sr sm dm grf256 cr0 sys pab bk0", {8, (LoopType) 0, 256, {(LoopType) 144, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 16777216}, {32, 16, 64}, {4, 8, 1}, 1, (WGType) 1, 257, 16384, 0, {16, 16, 4}, {false, false, false}}, {'E', 17, {930230, 383972, 0, 0, 0, 0, 1.36662, 2.39816, 6.07666, 16.7056, 0.00930946, 0.00736716, 0.0110739, 1, 1.22963, 1.21426, 6.39235e-14}}},
Expand Down Expand Up @@ -258,6 +263,11 @@ auto _CATALOG_ = kcatalog::toArray({
{{'E', "gemm", {"H", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "xyI"}, "sB16 sB16 aB wg 8x8 cab4 ks16 xaf dw vav bo bk0 sm sn grf256 sys pab l4 sr", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {524288, 524288, 16777216}, {32, 32, 16}, {8, 8, 1}, 1, (WGType) 1, 257, 65536, 0, {2, 2, 4}, {false, false, true}}, {'E', 17, {599248, 1.33458e+06, 0, 0, 0, 0, 5.52852, 5.45748, 6.54024, 18.2766, 0.0211426, 0.0211426, 0, 1, 1.40808, 1.175, 1.4855e-12}}},
{{'E', "gemm", {"H", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "Inpxy"}, "sB16 sB32x2 aB wg 2x8x2 kr ca4x2 ks64 xaf st dw vav bo sr bk0 dm grf256 sys", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 2}, {524288, 262144, 16777216}, {524288, 262144, 16777216}, {32, 16, 64}, {2, 8, 2}, 1, (WGType) 1, 261, 32768, 32768, {4, 4, 4}, {false, false, true}}, {'E', 17, {1.06367e+06, 203220, 74026.4, 363426, 0, 0, 5.22592, 5.83255, 5.53361, 14.4283, 0.0245863, -0.000813833, 0.0346436, 0.773746, 1.49124, 1.1734, 2.5839e-12}}},
{{'E', "gemm", {"H", "H", "S"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "Ip"}, "aB32 aB16 aB wg 2x4x4 kr ca3 ks64 af dw vav bo sr bk0 sm dm sys grf256", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 2}, {262144, 262144, 16777216}, {262144, 262144, 16777216}, {16, 16, 64}, {2, 4, 4}, 1, (WGType) 1, 261, 12288, 12288, {4, 4, 4}, {true, true, true}}, {'E', 17, {1.13288e+06, 578431, 81868.2, 83815, 0, 0, 6.01706, 6.08456, 4.49408, 11.3129, 0.0520799, 0.0402399, 0.0372522, 0.940886, 1.2079, 1.2015, 4.21515e-15}}},
{{'E', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB8 aS8x2 aB wg 2x8 kc8 ca4x2 ks8 nse bo sr grf256 dm", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {1048576, 524288, 16777216}, {1048576, 524288, 16777216}, {64, 32, 8}, {2, 8, 1}, 1, (WGType) 1, 1, 8192, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {2.3366e+06, 901418, 0, 0, 0, 0, 4.5252, 5.38337, 2.58269, 8.3107, 0.0632927, 0.0574993, 0.00939355, 0.75769, 1.02346, 1.29219, -3.62113e-11}}},
{{'E', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, "s"}, "aB4x2 aS8x2 aB wg 2x8 kc4 ca3 ks8 nse bo sr grf256 dm", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {1048576, 393216, 16777216}, {1048576, 393216, 16777216}, {64, 24, 8}, {2, 8, 1}, 1, (WGType) 1, 1, 6144, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {4.41016e+06, 1.01312e+06, 0, 0, 0, 0, 4.34359, 4.9417, 2.19106, 7.17771, 0.0632068, 0.0551295, 0.0123284, 0.732487, 1.02596, 1.00174, 8.29021e-14}}},
{{'E', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB8 aB8 aB wg 4x8 kc8 cab4x2 ks8 nse bo sr sn dm", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 262144, 16777216}, {524288, 262144, 16777216}, {32, 16, 8}, {4, 8, 1}, 1, (WGType) 1, 1, 16384, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {1.74347e+06, 1.17127e+06, 0, 0, 0, 0, 6.25575, 5.60473, 1.54902, 5.826, 0.0636417, 0.0502852, 0.0139854, 0.604199, 1.05183, 1.03637, -1.45818e-14}}},
{{'E', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB8 aS8x2 aB wg 2x16 kc8 ca3x2 ks8 nse bo sr sm dm", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 196608, 16777216}, {524288, 196608, 16777216}, {32, 12, 8}, {2, 16, 1}, 1, (WGType) 1, 1, 3072, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {9.35401e+06, 1.14199e+06, 0, 0, 0, 0, 6.53667, 4.93606, 0.518581, 4.94042, 0.0626981, 0.0397429, 0.019565, 0.455735, 1.04375, 1.01837, 1.18605e-13}}},
{{'E', "gemm", {"H", "H", "H"}, {"T", "N", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {4, 4, 1}, ""}, "aB4x2 aS8x2 aB wg 2x16 kc4 ca3 ks8 nse bo sr dm", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 131072, 16777216}, {524288, 131072, 16777216}, {32, 8, 8}, {2, 16, 1}, 1, (WGType) 1, 1, 3072, 0, {4, 4, 2}, {true, true, true}}, {'E', 17, {1.33663e+06, 1.03862e+06, 0, 0, 0, 0, 5.40567, 4.4686, -0.384802, 4.05935, 0.0635569, 0.0343042, 0.0333069, 0.541742, 1.06988, 1.04013, 1.49008e-13}}},
{{'E', "gemm", {"H", "H", "H"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "qpi"}, "aB8 aB8/4 aB wg 4x8 kc8 cab4 ks8 nse bo sr bk0 sm l4", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {524288, 524288, 16777216}, {8192, 8192, 16777216}, {32, 32, 8}, {4, 8, 1}, 1, (WGType) 1, 1, 24576, 0, {2, 2, 2}, {true, true, true}}, {'W', 1, {1024}}},
{{'E', "gemm", {"H", "H", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "I"}, "aB16x2 aB16x2 aB wg 8x4 cab4 ks16 af dw vav bo bk0 sn grf256 sys l4 sr", {8, (LoopType) 0, 256, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {262144, 524288, 16777216}, {8192, 8192, 16777216}, {16, 32, 16}, {8, 4, 1}, 1, (WGType) 1, 257, 32768, 0, {2, 2, 4}, {true, true, true}}, {'E', 17, {1.30894e+06, 785550, 0, 0, 0, 0, 7.11935, 8.75656, 6.13129, 16.041, 0.0504259, 0.0420693, 0.0674256, 0.73758, 1.20696, 1.20187, 1.15297e-15}}},
{{'E', "gemm", {"H", "H", "S"}, {"T", "T", "N"}}, {-1, -1, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {-1, -1, -1}, {1, 1, 1}, "qpi"}, "aB2x2 aB2x2 aB wg 8x4 kc2 cab4 ks8 nse bo sr bk0 sm l4", {8, (LoopType) 0, 128, {(LoopType) 128, (LoopType) 255, (LoopType) 255}, {262144, 524288, 16777216}, {8192, 8192, 16777216}, {16, 32, 8}, {8, 4, 1}, 1, (WGType) 1, 257, 32768, 0, {2, 2, 4}, {true, true, true}}, {'E', 17, {1.34765e+06, 357923, 0, 0, 0, 0, 11.5919, 11.8886, 6.4384, 17.1449, 0.145301, 0.14134, 0.0125881, 0.886981, 1.17153, 1.00812, 9.37293e-12}}},
Expand Down
Loading