1414#define  VEC4_T ${buffer_gvec_type(DTYPE, 4 )} // 4-component vector type used for the mat1/qmat2 tiles and the partial sums
1515
1616#define  TILE_ROWS ${TILE_ROWS} // rows per output tile; out_row advances in steps of TILE_ROWS
17+ #define  TILE_TXCOLS ${TILE_TXCOLS} // texel columns per output tile; one texel column is 4 scalar columns
1718
1819#define  NGROUPS 8 // first extent of the shared partial-sum array, indexed by gid
1920#define  NWORKERS 8 // second extent of the shared partial-sum array, indexed by wid; also the per-iteration stride (in texels) along the input width
@@ -29,7 +30,10 @@ layout(std430) buffer;
2930
3031${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array= False)}
3132${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array= False)}
32- ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array= False)}
33+ $if  QUANT_NBITS ==  4 : 
34+   ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array= False)}
35+ $else : 
36+   ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array= False)}
3337${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array= False)}
3438
3539layout (push_constant) uniform  restrict  Block {
@@ -42,12 +46,23 @@ layout(push_constant) uniform restrict Block {
4246
4347layout (local_size_x_id =  0 , local_size_y_id =  1 , local_size_z_id =  2 ) in ;
4448
45- shared  VEC4_T partial_c [NGROUPS][NWORKERS][TILE_ROWS];
49+ shared  VEC4_T partial_sums [NGROUPS][NWORKERS][TILE_ROWS][TILE_TXCOLS ];
4650
4751void  main() {
48-   const  uint  out_width_ntexels =  divup4(out_sizes.x);
49-   const  uint  out_col =  (gl_GlobalInvocationID.x %  out_width_ntexels) <<  2 ;
50-   const  uint  out_row =  (gl_GlobalInvocationID.x /  out_width_ntexels) *  TILE_ROWS;
52+   //  txcol stands for "texel column". One txcol corresponds to 4 scalar columns. gl_GlobalInvocationID.x is split into a tile column (x % global_wg_x) and a tile row (x / global_wg_x).
53+   $if  TILE_TXCOLS >  1 : 
54+     const  uint  global_wg_x =  uint (divup(out_sizes.x, 4  *  TILE_TXCOLS)); // tiles across the output width; divup is presumably a ceiling division -- confirm against its definition
55+     const  uint  out_txcol =  uint (
56+       (gl_GlobalInvocationID.x %  global_wg_x) *  TILE_TXCOLS); // first texel column of this tile
57+   $else : 
58+     const  uint  global_wg_x =  uint (divup4(out_sizes.x)); // presumably ceil(out_sizes.x / 4) -- confirm against divup4
59+     const  uint  out_txcol =  uint (gl_GlobalInvocationID.x %  global_wg_x); // one texel column per tile, so no scaling is needed
60+ 
61+   const  uint  out_row =  uint (
62+     (gl_GlobalInvocationID.x /  global_wg_x) *  TILE_ROWS); // first output row of this tile
63+ 
64+   $if  QUANT_NBITS ==  4 : 
65+     const  uint  weight_txcol =  uint (out_txcol /  2 ); // half of out_txcol: the preload loop expands each packed weight texel into two qmat2 txcols
5166
5267  const  int  gid =  int (gl_LocalInvocationID.x); //  group id
5368  const  int  wid =  int (gl_LocalInvocationID.z); //  worker id
@@ -56,46 +71,78 @@ void main() {
5671    return ;
5772  }
5873
59-   VEC4_T a [TILE_ROWS];
60-   VEC4_T b[ 4 ];
61-   VEC4_T local_c [TILE_ROWS];
74+   VEC4_T mat1 [TILE_ROWS];
75+   VEC4_T qmat2[ 4 ][TILE_TXCOLS ];
76+   VEC4_T local_sums [TILE_ROWS][TILE_TXCOLS ];
6277
63-   [[unroll]] for  (int  i =  0 ; i <  TILE_ROWS; ++ i) {
64-     local_c[i] =  VEC4_T(0.0 );
78+   [[unroll]] for  (int  r =  0 ; r <  TILE_ROWS; ++ r) {
79+     $for  c in  range(TILE_TXCOLS): 
80+       local_sums[r][${c}] =  VEC4_T(0.0 );
6581  }
6682
67-   $if  SCALES_STORAGE ==  "buffer ": 
68-     const  VEC4_T scales =  VEC4_T(t_scales[out_col >>  2 ]);
69-   $else : 
70-     const  VEC4_T scales =  VEC4_T(texelFetch(t_scales, ivec2 (out_col >>  2 , 0 ), 0 ));
71- 
72-   for  (int  pos =  4  *  wid; pos <  in_sizes.x; pos +=  (4  *  NWORKERS)) {
73-     //  Preload t_weight
74-     [[unroll]] for  (int  i =  0 ; i <  4 ; i++ ) {
75-       $if  WEIGHT_STORAGE ==  "buffer ": 
76-         b[i] =  t_weight[((pos +  i) *  weight_sizes.x +  out_col) >>  2 ];
83+   VEC4_T scales[TILE_TXCOLS];
84+   $for  c in  range(TILE_TXCOLS): 
85+     $if  SCALES_STORAGE ==  "buffer ": 
86+       scales[${c}] =  VEC4_T(t_scales[out_txcol +  ${c}]);
87+     $else : 
88+       scales[${c}] =  VEC4_T(
89+         texelFetch(t_scales, ivec2 (out_txcol +  ${c}, 0 ), 0 ));
90+ 
91+   for  (int  pos =  (4  *  wid), txpos =  wid;
92+        pos <  in_sizes.x;
93+        pos +=  (4  *  NWORKERS), txpos +=  NWORKERS) {
94+     $if  WEIGHT_STORAGE ==  "buffer ": 
95+       uint  qmat2_bufi;
96+       uint  weight_row_txstride =  div4(weight_sizes.x);
97+ 
98+     //  Preload weight tensor: fill qmat2 with 4 consecutive weight rows (pos .. pos + 3). 4-bit weights are unpacked from nibbles: the high nibble goes to txcol c, the low nibble to txcol c + 1, and 8 is subtracted from each.
99+     [[unroll]] for  (int  r =  0 ; r <  4 ; r++ ) {
100+       $if  QUANT_NBITS ==  4 : 
101+         $for  c in  range(0 , TILE_TXCOLS, 2 ): 
102+           $if  WEIGHT_STORAGE ==  "buffer ": 
103+             qmat2_bufi =  (pos +  r) *  weight_row_txstride +  weight_txcol;
104+             const  u8vec4 packed_weight_tex =  t_weight[qmat2_bufi +  ${c}]; // NOTE(review): c advances by 2 while each packed texel covers 2 txcols -- confirm this offset (and the one in the texelFetch below) should not be c / 2 when TILE_TXCOLS > 2
105+           $else : 
106+             const  uvec4  packed_weight_tex =  texelFetch(
107+               t_weight, ivec2 (weight_txcol +  ${c}, pos +  r), 0 );
108+ 
109+           qmat2[r][${c}] =  (VEC4_T((packed_weight_tex &  0xF0) >>  4 ) -  8.0 );
110+           qmat2[r][${c +  1 }] =  (VEC4_T(packed_weight_tex &  0x0F) -  8.0 );
77111      $else : 
78-         b[i] =  VEC4_T(texelFetch(t_weight, ivec2 (out_col >>  2 , pos +  i), 0 ));
112+         $for  c in  range(TILE_TXCOLS): 
113+           $if  WEIGHT_STORAGE ==  "buffer ": 
114+             qmat2_bufi =  (pos +  r) *  weight_row_txstride +  out_txcol;
115+             qmat2[r][${c}] =  t_weight[qmat2_bufi +  ${c}];
116+           $else : 
117+             qmat2[r][${c}] =  VEC4_T(
118+               texelFetch(t_weight, ivec2 (out_txcol +  ${c}, pos +  r), 0 ));
79119    }
80-     //  Preload t_in
81-     for  (int  i =  0 ; i <  TILE_ROWS; i++ ) {
120+ 
121+     $if  IN_STORAGE ==  "buffer ": 
122+       uint  in_row_txstride =  div4(in_sizes.x);
123+ 
124+     //  Preload input tensor
125+     [[unroll]] for  (int  i =  0 ; i <  TILE_ROWS; i++ ) {
82126      $if  IN_STORAGE ==  "buffer ": 
83-         a [i] =  t_in[(( out_row +  i) *  in_sizes.x  +  pos)  >>   2 ];
127+         mat1 [i] =  t_in[(out_row +  i) *  in_row_txstride  +  txpos ];
84128      $else : 
85-         a[i] =  VEC4_T(texelFetch(t_in, ivec3 (pos >>  2 , out_row +  i, 0 ), 0 ));
129+         mat1[i] =  VEC4_T(
130+           texelFetch(t_in, ivec3 (txpos, out_row +  i, 0 ), 0 ));
86131    }
87132
88133    //  Accumulate partial output: the x/y/z/w components of each mat1 texel scale qmat2 rows 0/1/2/3 respectively, and the four products are summed into local_sums[r][c].
89-     [[unroll]] for  (int  i =  0 ; i <  TILE_ROWS; ++ i) {
90-         local_c[i] +=  a[i].x *  b[0 ] + 
91-                       a[i].y *  b[1 ] + 
92-                       a[i].z *  b[2 ] + 
93-                       a[i].w *  b[3 ];
134+     [[unroll]] for  (int  r =  0 ; r <  TILE_ROWS; ++ r) {
135+       $for  c in  range(TILE_TXCOLS): 
136+         local_sums[r][${c}] +=  mat1[r].x *  qmat2[0 ][${c}] + 
137+                                mat1[r].y *  qmat2[1 ][${c}] + 
138+                                mat1[r].z *  qmat2[2 ][${c}] + 
139+                                mat1[r].w *  qmat2[3 ][${c}];
94140    }
95141  }
96142
97-   [[unroll]] for  (int  i =  0 ; i <  TILE_ROWS; ++ i) {
98-     partial_c[gid][wid][i] =  local_c[i];
143+   [[unroll]] for  (int  r =  0 ; r <  TILE_ROWS; ++ r) {
144+     $for  c in  range(TILE_TXCOLS): 
145+       partial_sums[gid][wid][r][${c}] =  local_sums[r][${c}];
99146  }
100147
101148  memoryBarrierShared();
@@ -105,21 +152,33 @@ void main() {
105152    return ;
106153  }
107154
108-   VEC4_T c[TILE_ROWS];
155+   VEC4_T sums[TILE_ROWS][TILE_TXCOLS];
156+ 
157+   for  (int  r =  0 ; r <  TILE_ROWS; ++ r) {
158+     $for  c in  range(TILE_TXCOLS): 
159+       sums[r][${c}] =  VEC4_T(0.0 );
109160
110-   for  (int  row =  0 ; row <  TILE_ROWS; ++ row) {
111-     c[row] =  VEC4_T(0.0 );
112161    [[unroll]] for  (int  worker =  0 ; worker <  NWORKERS; ++ worker) {
113-       c[row] +=  partial_c[gid][worker][row];
162+       $for  c in  range(TILE_TXCOLS): 
163+         sums[r][${c}] +=  partial_sums[gid][worker][r][${c}];
114164    }
115165  }
116166
117-   [[unroll]] for  (int  i =  0 ; i <  TILE_ROWS; ++ i) {
118-     $if  OUT_STORAGE ==  "buffer ": 
119-       if  (out_row +  i <  out_sizes.y) {
120-         t_out[((out_row +  i) *  out_sizes.x +  out_col) >>  2 ] =  c[i] *  scales;
121-       }
122-     $else : 
123-       imageStore(t_out, ivec3 (out_col >>  2 , out_row +  i, 0 ), c[i] *  scales);
167+   $if  OUT_STORAGE ==  "buffer ": 
168+     uint  out_bufi;
169+     uint  out_row_txstride =  div4(out_sizes.x);
170+ 
171+   [[unroll]] for  (int  r =  0 ; r <  TILE_ROWS; ++ r) {
172+     $for  c in  range(TILE_TXCOLS): 
173+       $if  OUT_STORAGE ==  "buffer ": 
174+         if  (out_row +  r <  out_sizes.y) {
175+           out_bufi =  (out_row +  r) *  out_row_txstride +  out_txcol;
176+           t_out[out_bufi +  ${c}] =  sums[r][${c}] *  scales[${c}];
177+         }
178+       $else : 
179+         imageStore(
180+           t_out,
181+           ivec3 (out_txcol +  ${c}, out_row +  r, 0 ),
182+           sums[r][${c}] *  scales[${c}]);
124183  }
125184}
0 commit comments