11/*
2- * Copyright (c ) Meta Platforms, Inc. and affiliates.
2+ * Copyright (c) Meta Platforms, Inc. and affiliates.
33 * All rights reserved.
44 *
55 * This source code is licensed under the BSD-style license found in the
1414#define VEC4_T ${buffer_gvec_type(DTYPE, 4 )}
1515
1616#define TILE_ROWS ${TILE_ROWS}
17+ #define TILE_TXCOLS ${TILE_TXCOLS}
1718
1819#define NGROUPS 8
1920#define NWORKERS 8
@@ -29,7 +30,10 @@ layout(std430) buffer;
2930
3031${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array= False)}
3132${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array= False)}
32- ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array= False)}
33+ $if QUANT_NBITS == 4 :
34+ ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array= False)}
35+ $else :
36+ ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array= False)}
3337${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array= False)}
3438
3539layout (push_constant) uniform restrict Block {
@@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block {
4246
4347layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
4448
45- shared VEC4_T partial_c [NGROUPS][NWORKERS][TILE_ROWS];
49+ shared VEC4_T partial_sums [NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS ];
4650
4751void main() {
48- const uint out_width_ntexels = divup4(out_sizes.x);
49- const uint out_col = (gl_GlobalInvocationID.x % out_width_ntexels) << 2 ;
50- const uint out_row = (gl_GlobalInvocationID.x / out_width_ntexels) * TILE_ROWS;
52+ // txcol stands for "texel column". One txcol corresponds to 4 scalar columns.
53+ $if TILE_TXCOLS > 1 :
54+ const uint global_wg_x = uint (divup(out_sizes.x, 4 * TILE_TXCOLS));
55+ const uint out_txcol = uint (
56+ (gl_GlobalInvocationID.x % global_wg_x) * TILE_TXCOLS);
57+ $else :
58+ const uint global_wg_x = uint (divup4(out_sizes.x));
59+ const uint out_txcol = uint (gl_GlobalInvocationID.x % global_wg_x);
60+
61+ const uint out_row = uint (
62+ (gl_GlobalInvocationID.x / global_wg_x) * TILE_ROWS);
63+
64+ $if QUANT_NBITS == 4 :
65+ const uint weight_txcol = uint (out_txcol / 2 );
5166
5267 const int gid = int (gl_LocalInvocationID.x); // group id
5368 const int wid = int (gl_LocalInvocationID.z); // worker id
@@ -56,46 +71,78 @@ void main() {
5671 return ;
5772 }
5873
59- VEC4_T a [TILE_ROWS];
60- VEC4_T b[ 4 ];
61- VEC4_T local_c [TILE_ROWS];
74+ VEC4_T mat1 [TILE_ROWS];
75+ VEC4_T qmat2[ 4 ][TILE_TXCOLS ];
76+ VEC4_T local_sums [TILE_ROWS][TILE_TXCOLS ];
6277
63- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
64- local_c[i] = VEC4_T(0.0 );
78+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
79+ $for c in range(TILE_TXCOLS):
80+ local_sums[r][${c}] = VEC4_T(0.0 );
6581 }
6682
67- $if SCALES_STORAGE == "buffer ":
68- const VEC4_T scales = VEC4_T(t_scales[out_col >> 2 ]);
69- $else :
70- const VEC4_T scales = VEC4_T(texelFetch(t_scales, ivec2 (out_col >> 2 , 0 ), 0 ));
71-
72- for (int pos = 4 * wid; pos < in_sizes.x; pos += (4 * NWORKERS)) {
73- // Preload t_weight
74- [[unroll]] for (int i = 0 ; i < 4 ; i++ ) {
75- $if WEIGHT_STORAGE == "buffer ":
76- b[i] = t_weight[((pos + i) * weight_sizes.x + out_col) >> 2 ];
83+ VEC4_T scales[TILE_TXCOLS];
84+ $for c in range(TILE_TXCOLS):
85+ $if SCALES_STORAGE == "buffer ":
86+ scales[${c}] = VEC4_T(t_scales[out_txcol + ${c}]);
87+ $else :
88+ scales[${c}] = VEC4_T(
89+ texelFetch(t_scales, ivec2 (out_txcol + ${c}, 0 ), 0 ));
90+
91+ for (int pos = (4 * wid), txpos = wid;
92+ pos < in_sizes.x;
93+ pos += (4 * NWORKERS), txpos += NWORKERS) {
94+ $if WEIGHT_STORAGE == "buffer ":
95+ uint qmat2_bufi;
96+ uint weight_row_txstride = div4(weight_sizes.x);
97+
98+ // Preload weight tensor
99+ [[unroll]] for (int r = 0 ; r < 4 ; r++ ) {
100+ $if QUANT_NBITS == 4 :
101+ $for c in range(0 , TILE_TXCOLS, 2 ):
102+ $if WEIGHT_STORAGE == "buffer ":
103+ qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
104+ const u8vec4 packed_weight_tex = t_weight[qmat2_bufi + ${c}]
105+ $else :
106+ const uvec4 packed_weight_tex = texelFetch(
107+ t_weight, ivec2 (weight_txcol + ${c}, pos + r), 0 );
108+
109+ qmat2[r][${c}] = (VEC4_T((packed_weight_tex & 0xF0) >> 4 ) - 8.0 );
110+ qmat2[r][${c + 1 }] = (VEC4_T(packed_weight_tex & 0x0F) - 8.0 );
77111 $else :
78- b[i] = VEC4_T(texelFetch(t_weight, ivec2 (out_col >> 2 , pos + i), 0 ));
112+ $for c in range(TILE_TXCOLS):
113+ $if WEIGHT_STORAGE == "buffer ":
114+ qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
115+ qmat2[r][${c}] = t_weight[qmat2_bufi + ${c}];
116+ $else :
117+ qmat2[r][${c}] = VEC4_T(
118+ texelFetch(t_weight, ivec2 (out_txcol + ${c}, pos + r), 0 ));
79119 }
80- // Preload t_in
81- for (int i = 0 ; i < TILE_ROWS; i++ ) {
120+
121+ $if IN_STORAGE == "buffer ":
122+ uint in_row_txstride = div4(in_sizes.x);
123+
124+ // Preload input tensor
125+ [[unroll]] for (int i = 0 ; i < TILE_ROWS; i++ ) {
82126 $if IN_STORAGE == "buffer ":
83- a [i] = t_in[(( out_row + i) * in_sizes.x + pos) >> 2 ];
127+ mat1 [i] = t_in[(out_row + i) * in_row_txstride + txpos ];
84128 $else :
85- a[i] = VEC4_T(texelFetch(t_in, ivec3 (pos >> 2 , out_row + i, 0 ), 0 ));
129+ mat1[i] = VEC4_T(
130+ texelFetch(t_in, ivec3 (txpos, out_row + i, 0 ), 0 ));
86131 }
87132
88133 // Accumulate partial output
89- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
90- local_c[i] += a[i].x * b[0 ] +
91- a[i].y * b[1 ] +
92- a[i].z * b[2 ] +
93- a[i].w * b[3 ];
134+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
135+ $for c in range(TILE_TXCOLS):
136+ local_sums[r][${c}] += mat1[r].x * qmat2[0 ][${c}] +
137+ mat1[r].y * qmat2[1 ][${c}] +
138+ mat1[r].z * qmat2[2 ][${c}] +
139+ mat1[r].w * qmat2[3 ][${c}];
94140 }
95141 }
96142
97- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
98- partial_c[gid][wid][i] = local_c[i];
143+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
144+ $for c in range(TILE_TXCOLS):
145+ partial_sums[gid][wid][r][${c}] = local_sums[r][${c}];
99146 }
100147
101148 memoryBarrierShared();
@@ -105,21 +152,33 @@ void main() {
105152 return ;
106153 }
107154
108- VEC4_T c[TILE_ROWS];
155+ VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
156+
157+ for (int r = 0 ; r < TILE_ROWS; ++ r) {
158+ $for c in range(TILE_TXCOLS):
159+ sums[r][${c}] = VEC4_T(0.0 );
109160
110- for (int row = 0 ; row < TILE_ROWS; ++ row) {
111- c[row] = VEC4_T(0.0 );
112161 [[unroll]] for (int worker = 0 ; worker < NWORKERS; ++ worker) {
113- c[row] += partial_c[gid][worker][row];
162+ $for c in range(TILE_TXCOLS):
163+ sums[r][${c}] += partial_sums[gid][worker][r][${c}];
114164 }
115165 }
116166
117- [[unroll]] for (int i = 0 ; i < TILE_ROWS; ++ i) {
118- $if OUT_STORAGE == "buffer ":
119- if (out_row + i < out_sizes.y) {
120- t_out[((out_row + i) * out_sizes.x + out_col) >> 2 ] = c[i] * scales;
121- }
122- $else :
123- imageStore(t_out, ivec3 (out_col >> 2 , out_row + i, 0 ), c[i] * scales);
167+ $if OUT_STORAGE == "buffer ":
168+ uint out_bufi;
169+ uint out_row_txstride = div4(out_sizes.x);
170+
171+ [[unroll]] for (int r = 0 ; r < TILE_ROWS; ++ r) {
172+ $for c in range(TILE_TXCOLS):
173+ $if OUT_STORAGE == "buffer ":
174+ if (out_row + r < out_sizes.y) {
175+ out_bufi = (out_row + r) * out_row_txstride + out_txcol;
176+ t_out[out_bufi + ${c}] = sums[r][${c}] * scales[${c}];
177+ }
178+ $else :
179+ imageStore(
180+ t_out,
181+ ivec3 (out_txcol + ${c}, out_row + r, 0 ),
182+ sums[r][${c}] * scales[${c}]);
124183 }
125184}
0 commit comments