1414
1515#define TILE_SIZE ${TILE_SIZE}
1616
17+ #define BATCH_SIZE_X ${BATCH_SIZE_X}
18+
1719#define BATCH_SIZE_Y ${BATCH_SIZE_Y}
1820
1921#define op(X, A, B) ${OPERATOR}
@@ -41,13 +43,15 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
4143 * output at a single output location.
4244 */
4345void main() {
46+ // x divided up by batch size is used to determine 3d position
4447 // y divided up by batch size is used to determine 3d position
4548 // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
46- const int out_limits_y_scaled = (out_limits.y + BATCH_SIZE_Y - 1 ) / BATCH_SIZE_Y;
49+ const ivec2 out_limits_xy_scaled = ivec2 (out_limits.xy + ivec2 (BATCH_SIZE_X, BATCH_SIZE_Y) - 1 ) / ivec2 (BATCH_SIZE_X, BATCH_SIZE_Y) ;
4750
48- u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits .x, out_limits_y_scaled );
51+ u16vec3 pos = idx_to_u16pos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled .x, out_limits_xy_scaled.y );
4952
50- // scale pos.y by batch size, because that's the top pixel to be processed
53+ // scale pos.xy by batch sizes, because that's the top pixel to be processed
54+ pos.x *= uint16_t(BATCH_SIZE_X);
5155 pos.y *= uint16_t(BATCH_SIZE_Y);
5256
5357 // do not process if top pixel does not fit within the output range
@@ -65,46 +69,54 @@ void main() {
6569 const u16vec2 end = ipos + u16vec2(overlay_region.xy);
6670
6771 // sum outputs
68- VEC4_T sum[BATCH_SIZE_Y];
72+ VEC4_T sum[BATCH_SIZE_Y][BATCH_SIZE_X] ;
6973
70- sum[0 ] = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
71- for (int i = 1 ; i < BATCH_SIZE_Y; i++ ) {
72- sum[i] = sum[0 ];
74+ sum[0 ][0 ] = texelFetch(t_bias, u16vec2(pos.z, 0 ), 0 );
75+ for (int y = 0 ; y < BATCH_SIZE_Y; y++ ) {
76+ for (int x = 0 ; x < BATCH_SIZE_X; x++ ) {
77+ sum[y][x] = sum[0 ][0 ];
78+ }
7379 }
7480
7581 // array to store input texels
76- VEC4_T in_texels[TILE_SIZE];
82+ VEC4_T in_texels[TILE_SIZE + BATCH_SIZE_X - 1 ];
7783
7884 // array to store kernel data of previous y
7985 VEC4_T prev_kernel_line[TILE_SIZE];
8086
8187 uint16_t kx = uint16_t(0 );
8288 for (uint16_t y = start.y, i = uint16_t(0 ); i < uint16_t(TILE_SIZE + BATCH_SIZE_Y - 1 ); y += uint16_t(dilation.y), i++ ) {
83- for (uint16_t x = start.x, j = uint16_t(0 ); j < uint16_t(TILE_SIZE); x += uint16_t(dilation.x), j++ ) {
89+ for (uint16_t x = start.x, j = uint16_t(0 ); j < uint16_t(TILE_SIZE + BATCH_SIZE_X - 1 ); x += uint16_t(dilation.x), j++ ) {
8490 in_texels[int (j)] = texelFetch(t_in, u16vec3(x, y, pos.z), 0 );
8591 }
8692
8793 // from 2nd iteration onwards accumulate dot product in 2nd sum
8894 // based on kernel line data fetched in previous iteration and input texel from this iteration
8995 if (i > uint16_t(0 )) {
90- for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ ) {
91- sum[1 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[1 ]);
96+ for (uint16_t s = uint16_t(0 ); s < uint16_t(BATCH_SIZE_X); s++ ) {
97+ for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ ) {
98+ sum[1 ][int (s)] = fma(in_texels[int (j+ s)], prev_kernel_line[int (j)], sum[1 ][int (s)]);
99+ }
92100 }
93101 }
94102
95103 // accumulate dot product in 1st sum only until tile size
96104 if (i < uint16_t(TILE_SIZE)) {
97105 for (uint16_t j = uint16_t(0 ); j < uint16_t(TILE_SIZE); j++ , kx++ ) {
98106 prev_kernel_line[int (j)] = texelFetch(t_kernel, u16vec2(kx, pos.z), 0 );
99- sum[0 ] = fma(in_texels[int (j)], prev_kernel_line[int (j)], sum[0 ]);
107+ for (uint16_t s = uint16_t(0 ); s < uint16_t(BATCH_SIZE_X); s++ ) {
108+ sum[0 ][int (s)] = fma(in_texels[int (j+ s)], prev_kernel_line[int (j)], sum[0 ][int (s)]);
109+ }
100110 }
101111 }
102112 }
103113
104114 for (int i = 0 ; i < BATCH_SIZE_Y; i++ ) {
105- if (any (greaterThanEqual (u16vec3(pos.x, pos.y + i, pos.z), out_limits))) {
106- continue ;
115+ for (int j = 0 ; j < BATCH_SIZE_X; j++ ) {
116+ if (any (greaterThanEqual (u16vec3(pos.x + j, pos.y + i, pos.z), out_limits))) {
117+ continue ;
118+ }
119+ imageStore(t_out, u16vec3(pos.x + j, pos.y + i, pos.z), op(sum[i][j], out_min, out_max));
107120 }
108- imageStore(t_out, u16vec3(pos.x, pos.y + i, pos.z), op(sum[i], out_min, out_max));
109121 }
110122}
0 commit comments