From 4d18ba51b80225eef517ae36286525d2e9f44818 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Tue, 17 Sep 2024 14:48:50 -0700
Subject: [PATCH] [ET-VK][BE] vTensor cleanup 5/N - clean up `indexing_utils.h` and clarify function names

## Context

NOTE: This diff is a WIP. I still have to update some callsites and make sure there are no performance regressions. However, I'm seeking some preliminary feedback on the new function names in `indexing_utils.h`.

The goal of this diff is to clean up the `indexing_utils.h` header file by introducing more consistent terminology for tensor properties and renaming the functions so that it is crystal clear what each one does.

## Notes for reviewers

As with the last diff, the majority of the meaningful changes are in `indexing_utils.h`; the rest are function name changes at the callsites. There should be no functionality changes, with the exception of rewriting the functions that interact with texture positions so that their input arguments can be `const`.

Differential Revision: [D62901892](https://our.internmc.facebook.com/intern/diff/D62901892/)

[ghstack-poisoned]
---
 .../graph/ops/glsl/addmm_naive_texture3d.glsl | 42 ++-
 .../graph/ops/glsl/addmm_optimized.glsl | 2 +-
 .../graph/ops/glsl/buffer_to_nchw.glsl | 4 +-
 .../ops/glsl/conv2d_dw_prepack_weights.glsl | 2 +-
 .../ops/glsl/conv2d_prepack_weights.glsl | 2 +-
 .../conv_transpose2d_prepack_weights.glsl | 2 +-
 .../runtime/graph/ops/glsl/image_to_nchw.glsl | 12 +-
 .../graph/ops/glsl/index_select_channel.glsl | 4 +-
 .../runtime/graph/ops/glsl/indexing_utils.h | 260 +++++++++++-------
 .../ops/glsl/int8_image_to_nchw_noint8.glsl | 2 +-
 .../graph/ops/glsl/matmul_naive_buffer.glsl | 6 +-
 .../graph/ops/glsl/nchw_to_buffer.glsl | 10 +-
 .../runtime/graph/ops/glsl/nchw_to_image.glsl | 22 +-
 .../ops/glsl/nchw_to_int8_image_noint8.glsl | 2 +-
 .../runtime/graph/ops/glsl/q_4w_linear.glsl | 10 +-
 .../runtime/graph/ops/glsl/q_8w_linear.glsl | 8 +-
 .../runtime/graph/ops/glsl/slice_channel.glsl | 4 +-
 .../vulkan/runtime/graph/ops/glsl/view.glsl | 4 +-
 18 files changed, 232 insertions(+), 166 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
index 911a2f37ce9..5a95b395fee 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl
@@ -63,11 +63,11 @@ vec4 get_bias_texel_W_packed(ivec3 logical_pos) {
 }
 #endif // HAS_BIAS
 
-vec4 matmul_naive_k_dim_packed(const ivec3 out_mpos) {
+vec4 matmul_naive_k_dim_packed(const ivec3 out_lpos) {
   ivec3 mat1_pos;
   mat1_pos[mat1_axis_map.x] = 0;
-  mat1_pos[mat1_axis_map.y] = out_mpos.y;
-  mat1_pos[mat1_axis_map.z] = out_mpos.z;
+  mat1_pos[mat1_axis_map.y] = out_lpos.y;
+  mat1_pos[mat1_axis_map.z] = out_lpos.z;
 #ifdef MAT2_IS_TRANSPOSED
   const int mat2_k_axis = mat2_axis_map.x;
   const int mat2_row_axis = mat2_axis_map.y;
@@ -88,9 +88,9 @@ vec4 matmul_naive_k_dim_packed(const ivec3 out_mpos) {
     // latency. Surprisingly, this doesn't translate to mat1_pos.
ivec3 mat2_pos = ivec3(0); mat2_pos[mat2_k_axis] = i; - mat2_pos[mat2_row_axis] = out_mpos.x * 4 + r; + mat2_pos[mat2_row_axis] = out_lpos.x * 4 + r; #ifndef MAT2_IS_TRANSPOSED - mat2_pos[mat2_axis_map.z] = out_mpos.z; + mat2_pos[mat2_axis_map.z] = out_lpos.z; #endif // MAT2_IS_TRANSPOSED sums[r] = dot(mat1_tex, texelFetch(mat2_tensor, mat2_pos, 0)); } @@ -103,16 +103,16 @@ vec4 matmul_naive_k_dim_packed(const ivec3 out_mpos) { return texel; } -vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_mpos) { +vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) { ivec3 mat1_pos; mat1_pos[mat1_axis_map.x] = 0; - mat1_pos[mat1_axis_map.y] = out_mpos.y; - mat1_pos[mat1_axis_map.z] = out_mpos.z; + mat1_pos[mat1_axis_map.y] = out_lpos.y; + mat1_pos[mat1_axis_map.z] = out_lpos.z; ivec3 mat2_pos; - mat2_pos[mat2_axis_map.x] = out_mpos.x; + mat2_pos[mat2_axis_map.x] = out_lpos.x; mat2_pos[mat2_axis_map.y] = 0; - mat2_pos[mat2_axis_map.z] = out_mpos.z; + mat2_pos[mat2_axis_map.z] = out_lpos.z; ivec3 mat2_pos_offset = ivec3(0); mat2_pos_offset[mat2_axis_map.y] = 1; @@ -131,9 +131,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_mpos) { // On-demand construction of mat2_pos appears to provide the lowest // latency. Surprisingly, this doesn't translate to mat1_pos. ivec3 mat2_pos = ivec3(0); - mat2_pos[mat2_axis_map.x] = out_mpos.x; + mat2_pos[mat2_axis_map.x] = out_lpos.x; mat2_pos[mat2_axis_map.y] = 4 * i + r; - mat2_pos[mat2_axis_map.z] = out_mpos.z; + mat2_pos[mat2_axis_map.z] = out_lpos.z; vec4 mat1_comp_vec = vec4(mat1_tex[r]); texel = fma(mat1_comp_vec, texelFetch(mat2_tensor, mat2_pos, 0), texel); @@ -144,8 +144,8 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_mpos) { } void main() { - const ivec3 out_mpos = ivec3(gl_GlobalInvocationID); - if (any(greaterThanEqual(out_mpos, out_logical_limits))) { + const ivec3 out_lpos = ivec3(gl_GlobalInvocationID); + if (any(greaterThanEqual(out_lpos, out_logical_limits))) { return; } @@ -153,24 +153,22 @@ void main() { #ifdef MAT2_IS_TRANSPOSED if (mat2_packed_dim == W_DIM) { - texel = matmul_naive_k_dim_packed(out_mpos); + texel = matmul_naive_k_dim_packed(out_lpos); } else { - texel = matmul_naive_k_dim_packed_row_dim_packed(out_mpos); + texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos); } #else if (mat2_packed_dim == W_DIM) { - texel = matmul_naive_k_dim_packed_row_dim_packed(out_mpos); + texel = matmul_naive_k_dim_packed_row_dim_packed(out_lpos); } else { - texel = matmul_naive_k_dim_packed(out_mpos); + texel = matmul_naive_k_dim_packed(out_lpos); } #endif // MAT2_IS_TRANSPOSED #ifdef HAS_BIAS - vec4 bias_texel = get_bias_texel_W_packed(out_mpos); + vec4 bias_texel = get_bias_texel_W_packed(out_lpos); texel = beta * bias_texel + alpha * texel; #endif // HAS_BIAS - ivec3 out_pos = to_texture_pos(out_mpos, out_axis_map); - - imageStore(out_tensor, out_pos, texel); + imageStore(out_tensor, lpos_to_pos(out_lpos, out_axis_map), texel); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl index 3c0024713fa..ad794d6db49 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl @@ -158,7 +158,7 @@ FloatMatrix matmul_partial(const ivec4 out_idx_tl) { // void write_results_C_packed(const ivec4 out_idx_tl, FloatMatrix results) { - ivec3 out_pos = to_texture_pos( + ivec3 out_pos = tidx_to_pos( out_idx_tl, out_sizes, out_axis_map, 
out_packed_dim); for (int tile_c = 0; diff --git a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl index 58796879e85..1a1c397553a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/buffer_to_nchw.glsl @@ -28,8 +28,8 @@ void main() { return; } - ivec4 t_in_idx = from_nchw_buffer_i(out_id, in_sizes); - const int in_id = to_buffer_id(t_in_idx, in_strides); + ivec4 t_in_idx = nchwi_to_tidx(out_id, in_sizes); + const int in_id = tidx_to_bufi(t_in_idx, in_strides); nchw_buf[out_id] = t_in[in_id]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl index 18202e4a51f..1aeea757bdd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_prepack_weights.glsl @@ -53,7 +53,7 @@ void main() { } // Map tensor_idx to normal buffer_i - const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); + const ivec4 p0 = tidx_to_nchw_ixs(idx, sizes, packed_dim); // Compute modified tensor_idx by inverting the CPU function const int N = original_sizes.w; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl index 493a614ee81..bd2d1294a2b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_prepack_weights.glsl @@ -53,7 +53,7 @@ void main() { } // Map tensor_idx to normal buffer_i - const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); + const ivec4 p0 = tidx_to_nchw_ixs(idx, sizes, packed_dim); // Compute modified tensor_idx by inverting the CPU function const int N = original_sizes.w; diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl index d2978ffe7e6..9efd9967f2b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv_transpose2d_prepack_weights.glsl @@ -53,7 +53,7 @@ void main() { } // Map tensor_idx to normal buffer_i - const ivec4 p0 = get_texel_nchw_buffer_ixs(idx, sizes, packed_dim); + const ivec4 p0 = tidx_to_nchw_ixs(idx, sizes, packed_dim); // Compute modified tensor_idx by inverting the CPU function const int N = original_sizes.w; diff --git a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl index 1e88ffd5975..c46542a81b8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/image_to_nchw.glsl @@ -31,7 +31,7 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int packed_dim = C_DIM; void write_out_texel(VEC4_T texel, ivec4 tensor_idx) { - const ivec4 buf_indices = get_texel_nchw_buffer_ixs( + const ivec4 buf_indices = tidx_to_nchw_ixs( tensor_idx, sizes, packed_dim); @@ -51,13 +51,13 @@ void write_out_texel(VEC4_T texel, ivec4 tensor_idx) { } void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_map, packed_dim); + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map, packed_dim); - if 
(any(greaterThanEqual(tensor_idx, sizes))) { + if (any(greaterThanEqual(tidx, sizes))) { return; } - const VEC4_T intex = load_texel(t_in, pos); - write_out_texel(intex, tensor_idx); + const VEC4_T intex = load_texel(t_in, lpos_to_pos(lpos, axis_map)); + write_out_texel(intex, tidx); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl index ba60000f3d4..59df6b7c6f7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.glsl @@ -34,11 +34,11 @@ void main() { } const ivec4 idx = to_tensor_idx(out_pos, out_sizes, packed_dim); - const ivec4 buffer_ixs = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim); + const ivec4 buffer_ixs = tidx_to_nchw_ixs(idx, out_sizes, packed_dim); VEC4_T out_texel; for (int i = 0; i < 4; ++i) { - const ivec4 out_idx = from_nchw_buffer_i(buffer_ixs[i], out_sizes); + const ivec4 out_idx = nchwi_to_tidx(buffer_ixs[i], out_sizes); int out_channel = out_idx.z; int in_channel = texelFetch(t_idx, ivec3(out_channel, 0, 0), 0).x; diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h index 4eed38d9ea1..03533807a17 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h @@ -9,6 +9,24 @@ #ifndef INDEXING_UTILS_H #define INDEXING_UTILS_H +/* + * The functions defined in this header file use the following shorthand to + * represent tensor related data structures. + * + * pos - ivec3 texture position, used to fetch from an image texture via the + * texelFetch(image, pos, lod) GLSL function. + * lpos - ivec3 logical position, listed in WHC order. This is a permutation of + * texture position based on a tensor's axis_map. lpos.x is the position + * component that corresponds to the tensor's width dimension, lpos.y is + * the position component that corresponds to the tensor's height dim, + * and so on. + * tidx - ivec4 tensor indices, listed in WHCN order. + * bufi - scalar index into a GPU buffer that backs a tensor. + * nchwi - scalar index into a staging buffer for a tensor. The data in the + * staging buffer is stored in contiguous data layout, irrespective of + * the tensor's strides. + */ + // Width Dim Index, assuming (W, H, C, N) order #define W_DIM 0 // Height, assuming (W, H, C, N) order @@ -16,21 +34,6 @@ // Channels, assuming (W, H, C, N) order #define C_DIM 2 -/* - * Describes which texture axis the "batches" dimension runs along in a 4D - * texture. - * - * Currently it is set to 2 since we represent batches by concatenating along - * the channels dim, which has index 2 in (W, H, C, N) order and maps to the - * depth dimension of a texture, which also corresponds to index 2 in (x, y, z) - * order. - */ -#define BATCH_AXIS 2 - -// -// Basic Indexing Utility Macros and Functions -// - /* * Fast division by 4 using bit shifting */ @@ -39,7 +42,7 @@ /* * Divides input and rounds up to 4 */ -#define divup4(x) ((x + 3) / 4) +#define divup4(x) ((x + 3) >> 2) /* * Aligns input to the next multiple of 4 @@ -47,8 +50,8 @@ #define alignup4(x) ((x + 3) & -4) /* - * Input: (W, H, C, N) strides of a tensor - * Returns: the WHCN index of the fastest moving dimension + * Find the packed dimension of a tensor given its strides. The packed dimension + * is the "fastest moving" dimension which will have a stride of 1. 
*/ int find_packed_dim(const ivec4 strides) { int packed_dim = 0; @@ -62,60 +65,40 @@ int find_packed_dim(const ivec4 strides) { } /* - * Return the elements of a texture position such that the first element is the - * texture coordinate corresponding to the width dimension, the second element - * is the texture coordinate corresponding to the height dimension, and the - * third element is the texture coordinate corresponding to the channels - * dimension. + * Get the staging buffer indices that contain the data of the texel that + * corresponds to the provided tensor index. Since the texel have 4 elements, + * 4 buffer indices will be retrieved. */ -ivec3 get_logical_pos(const ivec3 pos, const ivec4 axis_map) { - return ivec3(pos[axis_map.x], pos[axis_map.y], pos[axis_map.z]); -} - -// -// (w, h, c, n) Tensor Index <-> Contiguous Buffer Index Conversion -// - -/* - * Input: (w, h, c, n) tensor index, (W, H, C, N) sizes of a tensor, which dim - * is packed along a texel - * Output: A ivec4 containing the buffer indices corresponding to each texel - * element. - */ -ivec4 get_texel_nchw_buffer_ixs(ivec4 idx, ivec4 sizes, int packed_dim) { +ivec4 tidx_to_nchw_ixs( + const ivec4 tidx, + const ivec4 sizes, + const int packed_dim) { ivec4 strides = ivec4(1, sizes.x, sizes.x * sizes.y, sizes.x * sizes.y * sizes.z); - int base_i = idx.x * strides.x + idx.y * strides.y + idx.z * strides.z + - idx.w * strides.w; + int base_i = tidx.x * strides.x + tidx.y * strides.y + tidx.z * strides.z + + tidx.w * strides.w; return base_i + ivec4(0, 1, 2, 3) * strides[packed_dim]; } -/* - * Input: Index into a tensor's data buffer, (W, H, C, N) sizes of a tensor - * Returns: The WCHN index of the tensor that corresponds to the specified - * buffer index, assuming the buffer has contiguous memory layout - */ -ivec4 from_nchw_buffer_i(int buf_i, ivec4 sizes) { +ivec4 nchwi_to_tidx(const int nchwi, const ivec4 sizes) { return ivec4( - buf_i % sizes.x, - (buf_i / (sizes.x)) % sizes.y, - (buf_i / (sizes.x * sizes.y)) % sizes.z, - (buf_i / (sizes.x * sizes.y * sizes.z))); + nchwi % sizes.x, + (nchwi / (sizes.x)) % sizes.y, + (nchwi / (sizes.x * sizes.y)) % sizes.z, + (nchwi / (sizes.x * sizes.y * sizes.z))); } -int to_nchw_buffer_i(const ivec4 tensor_idx, const ivec4 sizes) { - return tensor_idx.w * sizes.x * sizes.y * sizes.z + - tensor_idx.z * sizes.x * sizes.y + tensor_idx.y * sizes.x + tensor_idx.x; +int tidx_to_nchwi(const ivec4 tidx, const ivec4 sizes) { + return tidx.w * sizes.x * sizes.y * sizes.z + tidx.z * sizes.x * sizes.y + + tidx.y * sizes.x + tidx.x; } -/* - * Input: Texel buffer index, (W, H, C, N) strides of a tensor, which dim is - * packed along a texel - * Returns: The (w, h, c, n) tensor index corresponding to the buffer element - */ -ivec4 to_tensor_idx(int buffer_id, const ivec4 strides, const int packed_dim) { +// TODO(ssjia): make this function use dim order so that it can work with any +// dim order. Currently it assumes that the dim order is contiguous, except for +// the packed dim. +ivec4 bufi_to_tidx(int buffer_id, const ivec4 strides, const int packed_dim) { ivec4 idx; for (int i = 3; i >= 0; i--) { if (i != packed_dim) { @@ -127,28 +110,133 @@ ivec4 to_tensor_idx(int buffer_id, const ivec4 strides, const int packed_dim) { return idx; } -/* - * Input: Texel buffer index, (W, H, C, N) strides of a tensor - * Returns: The (w, h, c, n) tensor index corresponding to the buffer element - * - * This is a convenience overload of the above function. 
If the packed dim is - * not known, it can be found by finding the first dimension with a stride of 1. - * However, this process adds some overhead, so if performance is a concern then - * the above function should be used instead so that the packed dim is provided. - */ -ivec4 to_tensor_idx(int buffer_id, const ivec4 strides) { +// Convenience overload of the above function, which will determine the packed +// dim from the strides automatically so it doesn't have to be passed in as a +// function argument. +ivec4 bufi_to_tidx(const int buffer_id, const ivec4 strides) { int packed_dim = find_packed_dim(strides); - return to_tensor_idx(buffer_id, strides, packed_dim); + return bufi_to_tidx(buffer_id, strides, packed_dim); +} + +int tidx_to_bufi(const ivec4 tidx, ivec4 strides) { + return tidx.x * strides.x + tidx.y * strides.y + tidx.z * strides.z + + tidx.w * strides.w; +} + +ivec4 lpos_to_tidx( + const ivec3 lpos, + const ivec4 sizes, + const ivec4 axis_map, + const int packed_dim) { + int batch_inner_dim = axis_map.w; + int batch_inner_dim_size = batch_inner_dim == packed_dim + ? alignup4(sizes[batch_inner_dim]) + : sizes[batch_inner_dim]; + + // w index is just a placeholder, which will be adjusted later + ivec4 tidx = lpos.xyzx; + // Traversing one texel in the packed dimension traveres 4 tensor elements in + // that dimension + tidx[packed_dim] *= 4; + + if (sizes.w == 1) { + tidx.w = 0; + } else { + tidx.w = tidx[batch_inner_dim] / batch_inner_dim_size; + tidx[batch_inner_dim] %= batch_inner_dim_size; + } + return tidx; } +ivec3 tidx_to_lpos( + const ivec4 tidx, + const ivec4 sizes, + const ivec4 axis_map, + const int packed_dim) { + int batch_inner_dim = axis_map.w; + int batch_inner_dim_size = batch_inner_dim == packed_dim + ? alignup4(sizes[batch_inner_dim]) + : sizes[batch_inner_dim]; + + ivec3 lpos = tidx.xyz; + + // Adjust batch dim if needed + if (sizes.w > 1) { + lpos[batch_inner_dim] += tidx.w * batch_inner_dim_size; + } + // Fast division by 4, since moving 1 texel along the packed dim traverses 4 + // tensor elements. + lpos[packed_dim] >>= 2; + return lpos; +} + +ivec3 tidx_to_pos( + const ivec4 tidx, + const ivec4 sizes, + const ivec4 axis_map, + const int packed_dim) { + int batch_inner_dim = axis_map.w; + int batch_inner_dim_size = batch_inner_dim == packed_dim + ? alignup4(sizes[batch_inner_dim]) + : sizes[batch_inner_dim]; + + ivec3 pos; + for (int dim = 0; dim < 3; ++dim) { + pos[axis_map[dim]] = tidx[dim]; + } + + // Adjust batch dim if needed + if (sizes.w > 1) { + pos[axis_map[batch_inner_dim]] += tidx.w * batch_inner_dim_size; + } + // Fast division by 4, since moving 1 texel along the packed dim traverses 4 + // tensor elements. 
+ pos[axis_map[packed_dim]] >>= 2; + return pos; +} + +ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) { + ivec3 pos; + pos[axis_map.x] = lpos.x; + pos[axis_map.y] = lpos.y; + pos[axis_map.z] = lpos.z; + return pos; +} + +#ifdef USING_BUFFER +#define load_texel(buf, idx) buf[idx] +#elif defined(USING_TEXTURE2D) +#define load_texel(im, pos) texelFetch(im, pos.xy, 0) +#else // defined(USING_TEXTURE3D) +#define load_texel(im, pos) texelFetch(im, pos, 0) +#endif + +#ifdef USING_BUFFER +#define write_texel(buf, idx, texel) buf[idx] = texel +#elif defined(USING_TEXTURE2D) +#define write_texel(im, pos, texel) imageStore(im, pos.xy, texel) +#else // defined(USING_TEXTURE3D) +#define write_texel(im, pos, texel) imageStore(im, pos, texel) +#endif + +/************************ + * Deprecated Functions * + ************************/ + +// The below functions and macros are in the process of being deprecated in +// favor of newer indexing functions that account for axis mapping and have more +// explicit function names and more updated terminology. + /* - * Input: (w, h, c, n) tensor index, (W, H, C, N) strides of the tensor buffer - * Returns: the buffer index corresponding to the specified tensor index + * Describes which texture axis the "batches" dimension runs along in a 4D + * texture. + * + * Currently it is set to 2 since we represent batches by concatenating along + * the channels dim, which has index 2 in (W, H, C, N) order and maps to the + * depth dimension of a texture, which also corresponds to index 2 in (x, y, z) + * order. */ -int to_buffer_id(const ivec4 tensor_idx, ivec4 strides) { - return tensor_idx.x * strides.x + tensor_idx.y * strides.y + - tensor_idx.z * strides.z + tensor_idx.w * strides.w; -} +#define BATCH_AXIS 2 // // (w, h, c, n) Tensor Index <-> (x, y, z) Texture Position Conversion @@ -343,26 +431,6 @@ ivec3 to_texture_pos(const ivec3 logical_pos, const ivec4 axis_map) { return pos; } -// -// Texel Access and Storage -// - -#ifdef USING_BUFFER -#define load_texel(buf, idx) buf[idx] -#elif defined(USING_TEXTURE2D) -#define load_texel(im, pos) texelFetch(im, pos.xy, 0) -#else // defined(USING_TEXTURE3D) -#define load_texel(im, pos) texelFetch(im, pos, 0) -#endif - -#ifdef USING_BUFFER -#define write_texel(buf, idx, texel) buf[idx] = texel -#elif defined(USING_TEXTURE2D) -#define write_texel(im, pos, texel) imageStore(im, pos.xy, texel) -#else // defined(USING_TEXTURE3D) -#define write_texel(im, pos, texel) imageStore(im, pos, texel) -#endif - // // Miscellaneous Utility Functions and Macros // diff --git a/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl index b8a291fd044..fba3560a49b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/int8_image_to_nchw_noint8.glsl @@ -41,7 +41,7 @@ void main() { int in_buf_idx = 4 * out_buf_idx; [[unroll]] for (int i = 0; i < 4; ++i) { - const ivec4 tensor_idx = from_nchw_buffer_i(in_buf_idx, tensor_sizes); + const ivec4 tensor_idx = nchwi_to_tidx(in_buf_idx, tensor_sizes); const ivec4 texture_pos = to_texture_elem_pos( tensor_idx, tensor_sizes, packed_dim); values[i] = load_texel(t_in, texture_pos.xyz)[texture_pos.w]; diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl index 25a6a742779..e51e497a3c9 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl @@ -42,9 +42,9 @@ void main() { return; } - int mat1_id = to_buffer_id( + int mat1_id = tidx_to_bufi( ivec4(0, out_idx.y, out_idx.z, out_idx.w), mat1_strides); - int mat2_id = to_buffer_id( + int mat2_id = tidx_to_bufi( ivec4(out_idx.x, 0, out_idx.z, out_idx.w), mat2_strides); T sum = T(0.0); @@ -55,6 +55,6 @@ void main() { mat2_id += mat2_strides.y; } - const int out_id = to_buffer_id(out_idx, out_strides); + const int out_id = tidx_to_bufi(out_idx, out_strides); t_out[out_id] = T(sum); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl index d861972f935..ea4e0d300cc 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_buffer.glsl @@ -23,13 +23,13 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int UNUSED_packed_dim = W_DIM; void main() { - int out_id = int(gl_GlobalInvocationID.x); - if (out_id >= numel) { + int out_bufi = int(gl_GlobalInvocationID.x); + if (out_bufi >= numel) { return; } - ivec4 out_idx = to_tensor_idx(out_id, out_strides); - const int in_id = to_nchw_buffer_i(out_idx, out_sizes); + ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides); + const int in_nchwi = tidx_to_nchwi(out_tidx, out_sizes); - t_out[out_id] = nchw_in[in_id]; + t_out[out_bufi] = nchw_in[in_nchwi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl index d553ad3624f..d7dcf116269 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_image.glsl @@ -30,34 +30,34 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; layout(constant_id = 3) const int packed_dim = C_DIM; -VEC4_T read_texel(ivec4 tensor_idx) { - const ivec4 buf_indices = get_texel_nchw_buffer_ixs( - tensor_idx, +VEC4_T read_texel(ivec4 tidx) { + const ivec4 buf_indices = tidx_to_nchw_ixs( + tidx, sizes, packed_dim); VEC4_T texel = VEC4_T(0); - if (tensor_idx[packed_dim] < sizes[packed_dim]) { + if (tidx[packed_dim] < sizes[packed_dim]) { texel.x = SCALAR_T(nchw_in[buf_indices.x]); } - if (tensor_idx[packed_dim] + 1 < sizes[packed_dim]) { + if (tidx[packed_dim] + 1 < sizes[packed_dim]) { texel.y = SCALAR_T(nchw_in[buf_indices.y]); } - if (tensor_idx[packed_dim] + 2 < sizes[packed_dim]) { + if (tidx[packed_dim] + 2 < sizes[packed_dim]) { texel.z = SCALAR_T(nchw_in[buf_indices.z]); } - if (tensor_idx[packed_dim] + 3 < sizes[packed_dim]) { + if (tidx[packed_dim] + 3 < sizes[packed_dim]) { texel.w = SCALAR_T(nchw_in[buf_indices.w]); } return texel; } void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - const ivec4 tensor_idx = to_tensor_idx(pos, sizes, axis_map, packed_dim); - if (any(greaterThanEqual(tensor_idx, sizes))) { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map, packed_dim); + if (any(greaterThanEqual(tidx, sizes))) { return; } - write_texel(t_out, pos, read_texel(tensor_idx)); + write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl index 48b2abb2af2..30273b1968d 100644 --- 
a/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/nchw_to_int8_image_noint8.glsl @@ -36,7 +36,7 @@ int extend_sign(int x) { } ivec4 read_texel(ivec4 tensor_idx) { - const ivec4 buf_indices = get_texel_nchw_buffer_ixs( + const ivec4 buf_indices = tidx_to_nchw_ixs( tensor_idx, sizes, packed_dim); int shift = (1 << 8) - 1; diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl index 751d513d59d..b755650957f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_4w_linear.glsl @@ -74,20 +74,20 @@ void main() { for (int kb = 0; kb < k_block; kb++) { scale_pos.x = kb; - const int scale_id = to_buffer_id(scale_pos, scales_strides); + const int scale_id = tidx_to_bufi(scale_pos, scales_strides); const float scale = float(t_scales_and_zeros[scale_id]); zero_pos.x = kb; - const int zero_id = to_buffer_id(zero_pos, scales_strides); + const int zero_id = tidx_to_bufi(zero_pos, scales_strides); const float zero = float(t_scales_and_zeros[zero_id]) - scale * 8.0; for(uint idx = 0; idx < group_size && k < K; idx++, k++) { mat1_pos.x = k; - const int mat1_id = to_buffer_id(mat1_pos, mat1_strides); + const int mat1_id = tidx_to_bufi(mat1_pos, mat1_strides); const float mat1_val = float(t_mat1[mat1_id]); mat2_pos.x = k / 2; - const int mat2_id = to_buffer_id(mat2_pos, mat2_strides); + const int mat2_id = tidx_to_bufi(mat2_pos, mat2_strides); // Bitwise op treats sign bit from int8 as a value bit instead, // since there is no uint8_t datatype uint mat2_val = (t_mat2[mat2_id] & 0xFF); @@ -97,7 +97,7 @@ void main() { } } - const int out_id = to_buffer_id(out_pos, out_strides); + const int out_id = tidx_to_bufi(out_pos, out_strides); t_out[out_id] = FLOAT_T(rc); #else // Using texture diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl index 7557a7b0c3d..a72df89b634 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl @@ -49,14 +49,14 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #ifdef USING_BUFFER void main() { - const int t_id = int(gl_GlobalInvocationID.x); - if (t_id >= out_numel) { + const int out_bufi = int(gl_GlobalInvocationID.x); + if (out_bufi >= out_numel) { return; } - const ivec4 out_idx = to_tensor_idx(t_id, out_strides, 0); + const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0); - t_out[t_id] = q_8w_linear(out_idx, mat1_sizes.x); + t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x); } #else // USING_TEXTURE diff --git a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl index d1562d65762..3b3ca5beb2e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/slice_channel.glsl @@ -43,11 +43,11 @@ void main() { // we calculate the source whcn-coordinate amended with offset-ed channel // value. Then we calculate the actual texture position from the // whcn-coordinate. 
- const ivec4 buf_indices = get_texel_nchw_buffer_ixs(idx, out_sizes, packed_dim); + const ivec4 buf_indices = tidx_to_nchw_ixs(idx, out_sizes, packed_dim); vec4 outex; for (int i=0;i<4;i++) { - ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], out_sizes); + ivec4 user_coor = nchwi_to_tidx(buf_indices[i], out_sizes); int in_channel = user_coor.z; diff --git a/backends/vulkan/runtime/graph/ops/glsl/view.glsl b/backends/vulkan/runtime/graph/ops/glsl/view.glsl index 0b0f587d1d5..acbc1445600 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view.glsl @@ -39,13 +39,13 @@ void main() { // Assume there is a virtual continous buffer in nchw format. From the output // pos, we first calculate the index in the virual buffer, and then calculate // the input position from the indx. - const ivec4 buf_indices = get_texel_nchw_buffer_ixs(out_tensor_idx, out_sizes, out_packed_dim); + const ivec4 buf_indices = tidx_to_nchw_ixs(out_tensor_idx, out_sizes, out_packed_dim); VEC4_T value = VEC4_T(0); // Need to look up the 4 values in the output texel separately. for (int i = 0 ; i < 4; i++) { if (out_tensor_idx[out_packed_dim]++ < out_sizes[out_packed_dim]) { - ivec4 user_coor = from_nchw_buffer_i(buf_indices[i], in_sizes); + ivec4 user_coor = nchwi_to_tidx(buf_indices[i], in_sizes); ivec4 in_pos_elem = to_texture_elem_pos(user_coor, in_sizes, in_packed_dim); VEC4_T intex = texelFetch(t_in, in_pos_elem.xyz, 0); value[i] = intex[in_pos_elem.w];
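For reviewers, the intent behind the renaming is easiest to see in how the helpers compose in a typical texture-backed shader. The sketch below simply restates the pattern used in `nchw_to_image.glsl` and `image_to_nchw.glsl` above; it is illustrative only and assumes the `sizes`, `axis_map`, `packed_dim`, `t_out`, and `read_texel` declarations from those shaders.

```glsl
// Illustrative sketch only; assumes the layout declarations (sizes, axis_map,
// packed_dim), the t_out image, and the read_texel() helper from
// nchw_to_image.glsl.
void main() {
  // lpos: WHC-ordered logical position derived from the dispatch ID.
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  // tidx: WHCN tensor index, used to bounds-check against the tensor sizes.
  const ivec4 tidx = lpos_to_tidx(lpos, sizes, axis_map, packed_dim);
  if (any(greaterThanEqual(tidx, sizes))) {
    return;
  }
  // pos: physical texture position, i.e. lpos permuted by axis_map. For a
  // trivial axis_map such as (0, 1, 2, 2), pos and lpos coincide.
  write_texel(t_out, lpos_to_pos(lpos, axis_map), read_texel(tidx));
}
```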
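The buffer-backed path follows the same naming scheme. This sketch mirrors `nchw_to_buffer.glsl`, again assuming the `numel`, `out_strides`, `out_sizes`, `t_out`, and `nchw_in` declarations from that shader.

```glsl
// Illustrative sketch only; assumes numel, out_strides, out_sizes, t_out and
// nchw_in as declared in nchw_to_buffer.glsl.
void main() {
  // bufi: linear index into the GPU buffer that backs the tensor.
  const int out_bufi = int(gl_GlobalInvocationID.x);
  if (out_bufi >= numel) {
    return;
  }
  // bufi -> tidx (WHCN tensor index) -> nchwi (index into the contiguous NCHW
  // staging buffer).
  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides);
  const int in_nchwi = tidx_to_nchwi(out_tidx, out_sizes);
  t_out[out_bufi] = nchw_in[in_nchwi];
}
```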
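Finally, a worked example of the batch handling in `lpos_to_tidx`, using assumed values (`sizes = (8, 6, 4, 2)`, `axis_map = (0, 1, 2, 2)`, `packed_dim = C_DIM`); the numbers are only meant to spell out the divide/modulo step that splits the concatenated batch dimension back out.

```glsl
// Worked example with assumed values:
//   sizes      = ivec4(8, 6, 4, 2)  // W = 8, H = 6, C = 4, N = 2
//   axis_map   = ivec4(0, 1, 2, 2)  // batches concatenated along channels
//   packed_dim = C_DIM              // 4 channel elements per texel
// The texture holds N * divup4(C) = 2 texels along the channel axis, so
// lpos.z is in {0, 1}. For lpos = (3, 5, 1):
//   tidx = lpos.xyzx               -> (3, 5, 1, _)
//   tidx[packed_dim] *= 4          -> (3, 5, 4, _)
//   batch_inner_dim_size = alignup4(sizes.z) = 4
//   tidx.w = 4 / 4 = 1; tidx.z = 4 % 4 = 0
// giving tidx = (3, 5, 0, 1): the second depth slice holds the first channel
// texel of batch index 1.
```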