From cf907865a23b2d4671a665ef5b5916806c8bd21d Mon Sep 17 00:00:00 2001 From: Sicheng Stephen Jia Date: Fri, 14 Nov 2025 16:48:35 -0500 Subject: [PATCH] [ET-VK] Add int and bool tensor support for many operators (#15829) Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #15829 * #15796 * #15795 * #15794 * #15793 Title says it all! Adds `int32` and `uint8` shader variants to a bunch of operators that don't currently have variants for these dtypes, but should. This should prevent folks from running into dtype crashes at runtime when using the Vulkan delegate. Differential Revision: [D87082724](https://our.internmc.facebook.com/intern/diff/D87082724/) Co-authored-by: ssjia (cherry picked from commit a6c59218cf27ed25495bf34a0892710ebfe1e79b) --- backends/vulkan/runtime/graph/ops/glsl/clone.yaml | 2 ++ backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml | 1 + backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl | 4 ---- backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml | 1 + backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml | 1 + backends/vulkan/runtime/graph/ops/glsl/full.yaml | 1 + backends/vulkan/runtime/graph/ops/glsl/index_select.yaml | 1 + .../vulkan/runtime/graph/ops/glsl/index_select_channel.yaml | 1 + backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml | 2 ++ backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml | 2 ++ backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml | 2 ++ backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml | 2 ++ 12 files changed, 16 insertions(+), 4 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/clone.yaml b/backends/vulkan/runtime/graph/ops/glsl/clone.yaml index 1fdbf506bfd..a85d201046e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/clone.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/clone.yaml @@ -7,5 +7,7 @@ clone: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: clone diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml index 39f96df5e90..36d0b879bdd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_buffer.yaml @@ -6,6 +6,7 @@ concat_buffer: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: concat_1_buffer NUM_INPUTS: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl index afab0c524d6..0611defa4c3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.glsl @@ -113,8 +113,6 @@ void main() { VEC4_T out_texel = imageLoad(t_out, out_pos); - VEC4_T test_texel = VEC4_T(-1.0); - for (int comp = 0; comp < 4; ++comp) { ivec4 out_tidx = out_read_start_tidx; out_tidx[out_packed_dim] += comp; @@ -124,7 +122,6 @@ void main() { // of the previous input batch; if so, then don't overwrite this texel // element if (out_tidx[concat_dim] < concat_offset) { - test_texel[comp] = -5.0; continue; } @@ -164,7 +161,6 @@ void main() { inp${i}_packed_dim); out_texel[comp] = texelFetch(t_inp${i}, in_posi.xyz, 0)[in_posi.w]; - test_texel[comp] = out_texel[comp]; continue; } else { diff --git a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml index ed5003382a1..d3de77d8ea9 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/concat_texture.yaml @@ -6,6 +6,7 @@ concat_texture: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 shader_variants: - NAME: concat_1_texture3d NUM_INPUTS: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml index 6d90e1fa8b1..887f7893061 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml @@ -6,5 +6,6 @@ expand_buffer: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: expand_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/full.yaml b/backends/vulkan/runtime/graph/ops/glsl/full.yaml index 1a5b0cb235e..5d7a983cae3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/full.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/full.yaml @@ -15,5 +15,6 @@ full: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: full diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml index abef2225cd9..6bf4c71a3c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select.yaml @@ -8,5 +8,6 @@ index_select: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: index_select diff --git a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml index a306e3ce47d..716f7ecf2d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/index_select_channel.yaml @@ -8,5 +8,6 @@ index_select_channel: - VALUE: half - VALUE: float - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: index_select_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml index 02afc3846a2..91306bd4cbf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pad_channel.yaml @@ -8,5 +8,7 @@ pad_channel: DTYPE: - VALUE: float - VALUE: half + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: pad_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml b/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml index dd74ec9cc28..2eb57291bb2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pad_height_width.yaml @@ -8,5 +8,7 @@ pad_height_width: DTYPE: - VALUE: float - VALUE: half + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: pad_height_width diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml index 4147e82965a..c48237f7568 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_channel.yaml @@ -6,5 +6,7 @@ repeat_channel: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: repeat_channel diff --git a/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml b/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml index 5c284a580c9..f56172dc7f0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/repeat_interleave.yaml @@ -6,5 +6,7 @@ repeat_interleave: DTYPE: - VALUE: half - VALUE: float + - VALUE: int32 + - VALUE: uint8 shader_variants: - NAME: repeat_interleave