From 88544f608f6e118aa6fe2c6df94d65e04fd8a4d2 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 25 Feb 2026 07:26:53 -0800 Subject: [PATCH] [ET-VK] Add Vulkan ops for skin segmentation and EdgeTAM models Implement several missing Vulkan operators needed to reduce graph fragmentation in the skin segmentation and EdgeTAM models. **Skin segmentation ops:** - aten.where.self: already had C++ and GLSL implementations but was missing the Python partitioner registration. - aten.bitwise_and.Tensor: added as a new binary_op shader variant operating on uint8 (bool) tensors. **EdgeTAM partitioning fixes:** - Comparison ops (eq, lt, le, gt, ge): were registered under the generic BinaryOp features which inherited FP_INT_T as the output dtype set. The partitioner correctly rejected these because their outputs are bool tensors. Split them into a dedicated register_comparison_ops registration with outputs_dtypes=BOOL_T. The binary_op.glsl shader already handles bool output via the IS_COMPARISON_OP path (uint8 storage), so no shader changes are needed. - aten.copy.default: not in the op registry, causing a subgraph break in the first-frame model. This op appears when valid_num_points.to() is called with matching dtype (a no-op cast). Add it to RemoveRedundantOpsTransform so it is eliminated before the partitioner runs. Also register it as an ephemeral op as a fallback. The removal logic requires a _src_arg1_ops set to handle the copy.default(self, src) argument order, where the replacement target is args[1] (src) rather than args[0] (self) as in all other redundant ops. Differential Revision: [D94364641](https://our.internmc.facebook.com/intern/diff/D94364641/) [ghstack-poisoned] --- .../vulkan/_passes/remove_redundant_ops.py | 13 ++- backends/vulkan/op_registry.py | 50 ++++++++- .../runtime/graph/ops/glsl/binary_op.yaml | 8 ++ .../graph/ops/glsl/reduce_per_row_buffer.glsl | 2 +- .../vulkan/runtime/graph/ops/glsl/where.glsl | 101 +++++++++++------- .../runtime/graph/ops/impl/BinaryOp.cpp | 2 + .../vulkan/runtime/graph/ops/impl/Where.cpp | 55 ++-------- backends/vulkan/test/op_tests/cases.py | 23 ++++ 8 files changed, 166 insertions(+), 88 deletions(-) diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py index 25bdd34de70..b95733021fc 100644 --- a/backends/vulkan/_passes/remove_redundant_ops.py +++ b/backends/vulkan/_passes/remove_redundant_ops.py @@ -32,6 +32,13 @@ class RemoveRedundantOpsTransform(ExportPass): exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.dim_order_ops._clone_dim_order.default, exir_ops.edge.aten.expand_copy.default, + # copy.default(self, src): no-op when src dtype/shape matches self. + exir_ops.edge.aten.copy.default, + } + + # For these ops the meaningful input is args[1] (src), not args[0] (self). + _src_arg1_ops: Set[OpType] = { + exir_ops.edge.aten.copy.default, } def __init__(self) -> None: @@ -41,7 +48,8 @@ def _should_remove(self, node: torch.fx.Node) -> bool: if node.target not in self.redundant_ops: return False - orig_node = node.args[0] + src_arg_idx = 1 if node.target in self._src_arg1_ops else 0 + orig_node = node.args[src_arg_idx] assert isinstance(orig_node, torch.fx.Node) src_dtype = orig_node.meta["val"].dtype @@ -61,7 +69,8 @@ def _remove(self, graph_module: torch.fx.GraphModule) -> None: if not self._should_remove(node): continue - node.replace_all_uses_with(node.args[0]) + src_arg_idx = 1 if node.target in self._src_arg1_ops else 0 + node.replace_all_uses_with(node.args[src_arg_idx]) graph_module.graph.eliminate_dead_code() diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 855df9d2e74..e9bf7201b23 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -167,6 +167,9 @@ def update_features_impl(op: OpKey): # Guard and assert ops torch.ops.aten._assert_scalar.default, torch.ops.aten.sym_constrain_range_for_size.default, + # copy.default is a no-op when src dtype matches dst dtype; removed by + # RemoveRedundantOpsTransform before execution. + exir_ops.edge.aten.copy.default, ] ) def register_ephemeral_ops(): @@ -231,6 +234,19 @@ def register_clamp(): exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.div.Tensor_mode, exir_ops.edge.aten.pow.Tensor_Tensor, + ] +) +def register_binaryop_cpp_ops(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.FP_INT_T, + supports_resize=True, + supports_highdim=True, + ) + + +@update_features( + [ exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.le.Tensor, @@ -238,10 +254,26 @@ def register_clamp(): exir_ops.edge.aten.ge.Tensor, ] ) -def register_binaryop_cpp_ops(): +def register_comparison_ops(): return OpFeatures( inputs_storage=utils.ANY_STORAGE, inputs_dtypes=utils.FP_INT_T, + outputs_dtypes=utils.BOOL_T, + supports_resize=True, + supports_highdim=True, + ) + + +# ============================================================================= +# BinaryOp.cpp (bitwise) +# ============================================================================= + + +@update_features(exir_ops.edge.aten.bitwise_and.Tensor) +def register_bitwise_and(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.BOOL_T, supports_resize=True, supports_highdim=True, ) @@ -673,6 +705,7 @@ def register_argreduce_cpp_ops(): return OpFeatures( inputs_storage=utils.ANY_TEXTURE, inputs_dtypes=utils.FP_T, + outputs_dtypes=utils.INT_T, supports_resize=True, supports_highdim=True, are_node_inputs_supported_fn=is_reduce_node_supported, @@ -1157,6 +1190,21 @@ def register_index_select(): ) +# ============================================================================= +# Where.cpp +# ============================================================================= + + +@update_features(exir_ops.edge.aten.where.self) +def register_where(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=[utils.BOOL_T, utils.FP_T, utils.FP_T], + outputs_dtypes=utils.FP_T, + supports_resize=True, + ) + + # ============================================================================= # Arange.cpp # ============================================================================= diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml index ee96b5c05b4..c3d5cd00204 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml @@ -116,3 +116,11 @@ binary_op: - VALUE: half - VALUE: float - VALUE: int32 + - NAME: binary_bitwise_and + OPERATOR: X & Y + generate_variant_forall: + STORAGE: + - VALUE: buffer + - VALUE: texture3d + DTYPE: + - VALUE: uint8 diff --git a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl index b0c07e73637..3a63099e7df 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/reduce_per_row_buffer.glsl @@ -115,7 +115,7 @@ void main() { #endif #ifdef OUTPUT_IS_INDICES - t_out[out_bufi] = int(0); // int(local_accum.idx); + t_out[out_bufi] = int(local_accum.idx); #else t_out[out_bufi] = convert_to_T(local_accum.val); #endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/where.glsl b/backends/vulkan/runtime/graph/ops/glsl/where.glsl index 281b317e0b5..cab7cf54046 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/where.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/where.glsl @@ -1,5 +1,3 @@ -// where.glsl - /* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. @@ -8,7 +6,6 @@ * LICENSE file in the root directory of this source tree. */ - #version 450 core ${define_required_extensions(STORAGE, DTYPE)} @@ -24,44 +21,50 @@ ${define_active_storage_type(STORAGE)} layout(std430) buffer; +#include "indexing.glslh" + ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_condition", "bool", STORAGE)} ${layout_declare_tensor(B, "r", "t_self", DTYPE, STORAGE)} ${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)} - -#include "indexing_utils.h" - $if STORAGE == "buffer": - ${layout_declare_ubo(B, "int", "out_numl")} - ${layout_declare_ubo(B, "ivec4", "out_strides")} - ${layout_declare_ubo(B, "ivec4", "cond_strides")} - ${layout_declare_ubo(B, "ivec4", "self_strides")} - ${layout_declare_ubo(B, "ivec4", "other_strides")} + ${layout_declare_ubo(B, "BufferMetadata", "outp")} + ${layout_declare_ubo(B, "BufferMetadata", "condp")} + ${layout_declare_ubo(B, "BufferMetadata", "selfp")} + ${layout_declare_ubo(B, "BufferMetadata", "otherp")} $else: - ${layout_declare_ubo(B, "ivec3", "out_limits")} - -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_DIM_ORDER")} - -const lowp ivec4 out_dim_order = unhash_dim_order(out_layout); + ${layout_declare_ubo(B, "TextureMetadata", "outp")} + ${layout_declare_ubo(B, "TextureMetadata", "condp")} + ${layout_declare_ubo(B, "TextureMetadata", "selfp")} + ${layout_declare_ubo(B, "TextureMetadata", "otherp")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; #ifdef USING_BUFFER void main() { - int out_bufi = int(gl_GlobalInvocationID.x); - if (out_bufi >= out_numl) { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { return; } - const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, out_dim_order); + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + TensorIndex cond_tidx = out_tidx; + clamp_tensor_idx(condp, cond_tidx); - const int cond_bufi = tidx_to_bufi(out_tidx, cond_strides); - const int self_bufi = tidx_to_bufi(out_tidx, self_strides); - const int other_bufi = tidx_to_bufi(out_tidx, other_strides); + TensorIndex self_tidx = out_tidx; + clamp_tensor_idx(selfp, self_tidx); - COND_T cond = t_condition[cond_bufi] ; + TensorIndex other_tidx = out_tidx; + clamp_tensor_idx(otherp, other_tidx); + + const uint cond_bufi = tensor_idx_to_linear_idx(condp, cond_tidx); + const uint self_bufi = tensor_idx_to_linear_idx(selfp, self_tidx); + const uint other_bufi = tensor_idx_to_linear_idx(otherp, other_tidx); + + COND_T cond = t_condition[cond_bufi]; T v_self = t_self[self_bufi]; T v_other = t_other[other_bufi]; @@ -72,29 +75,49 @@ void main() { } } -#else // !USING_BUFFER +#else // USING_TEXTURE void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits))) { + if (out_of_bounds(out_pos, outp)) { return; } - vec4 cond = load_texel(t_condition, pos); - VEC4_T selftex = load_texel(t_self, pos); - VEC4_T othertex = load_texel(t_other, pos); - - VEC4_T outtex; - - for (int idx = 0; idx < 4; ++idx) { - if (cond[idx] == 1) { - outtex[idx] = selftex[idx]; + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + + VEC4_T outtex = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < limit; comp++) { + TensorIndex4D cond_tidx; + cond_tidx.data = min(out_tidx.data, condp.sizes - 1); + TextureElementIndex cond_elem = + tensor4d_idx_to_texture_element_idx_simple(condp, cond_tidx); + uint cond_val = texelFetch(t_condition, cond_elem.pos, 0)[cond_elem.comp]; + + TensorIndex4D self_tidx; + self_tidx.data = min(out_tidx.data, selfp.sizes - 1); + TextureElementIndex self_elem = + tensor4d_idx_to_texture_element_idx_simple(selfp, self_tidx); + VEC4_T self_texel = texelFetch(t_self, self_elem.pos, 0); + + TensorIndex4D other_tidx; + other_tidx.data = min(out_tidx.data, otherp.sizes - 1); + TextureElementIndex other_elem = + tensor4d_idx_to_texture_element_idx_simple(otherp, other_tidx); + VEC4_T other_texel = texelFetch(t_other, other_elem.pos, 0); + + if (cond_val > 0) { + outtex[comp] = self_texel[self_elem.comp]; } else { - outtex[idx] = othertex[idx]; + outtex[comp] = other_texel[other_elem.comp]; } + + out_tidx.data[outp.packed_dim]++; } - write_texel(t_out, pos, outtex); + + imageStore(t_out, out_pos, outtex); } - #endif // !USING_BUFFER +#endif diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 025b483eab7..92c2fa218ec 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -198,6 +198,7 @@ DEFINE_BINARY_OP_FN(lt); DEFINE_BINARY_OP_FN(le); DEFINE_BINARY_OP_FN(gt); DEFINE_BINARY_OP_FN(ge); +DEFINE_BINARY_OP_FN(bitwise_and); REGISTER_OPERATORS { VK_REGISTER_OP(aten.add.Tensor, add); @@ -212,6 +213,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.le.Tensor, le); VK_REGISTER_OP(aten.gt.Tensor, gt); VK_REGISTER_OP(aten.ge.Tensor, ge); + VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Where.cpp b/backends/vulkan/runtime/graph/ops/impl/Where.cpp index c1c482d9967..adb7fb1beca 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Where.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Where.cpp @@ -21,43 +21,13 @@ void resize_where_node( const std::vector& extra_args) { (void)extra_args; const ValueRef out = args.at(0).refs.at(0); - const ValueRef in = args.at(1).refs.at(0); + const ValueRef self = args.at(1).refs.at(1); - const std::vector in_sizes = graph->sizes_of(in); - graph->virtual_resize(out, in_sizes); + const std::vector self_sizes = graph->sizes_of(self); + graph->virtual_resize(out, self_sizes); } -void add_where_texture_node( - ComputeGraph& graph, - const ValueRef cond, - const ValueRef self, - const ValueRef other, - const ValueRef out) { - std::string kernel_name = "where"; - - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - default_pick_global_wg_size, - default_pick_local_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}}, - // Parameter buffers - {graph.logical_limits_ubo(self)}, - // Push Constants - {}, - // Specialization Constants - {graph.hashed_layout_of(out)}, - // Resize Arguments - {}, - // Resizing Logic - resize_where_node)); -} - -void add_where_buffer_node( +void add_where_node( ComputeGraph& graph, const ValueRef cond, const ValueRef self, @@ -69,11 +39,10 @@ void add_where_buffer_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); vkapi::ParamsBindList ubos = { - graph.numel_ubo(out), - graph.strides_ubo(out), - graph.strides_ubo(cond), - graph.strides_ubo(self), - graph.strides_ubo(other)}; + graph.meta_ubo(out), + graph.meta_ubo(cond), + graph.meta_ubo(self), + graph.meta_ubo(other)}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -87,7 +56,7 @@ void add_where_buffer_node( // Push Constants {}, // Specialization Constants - {graph.hashed_layout_of(out)}, + {}, // Resize Arguments {}, // Resizing Logic @@ -100,11 +69,7 @@ void where(ComputeGraph& graph, const std::vector& args) { const ValueRef self = args[args_i++]; const ValueRef other = args[args_i++]; const ValueRef out = args[args_i++]; - if (graph.is_buffer_storage(out)) { - add_where_buffer_node(graph, cond, self, other, out); - } else { - add_where_texture_node(graph, cond, self, other, out); - } + add_where_node(graph, cond, self, other, out); } REGISTER_OPERATORS { diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 534462ed179..5ed354ebab3 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -2001,6 +2001,29 @@ def get_where_inputs(): return test_suite +@register_test_suite("aten.bitwise_and.Tensor") +def get_bitwise_and_inputs(): + test_suite = VkTestSuite( + [ + ((M1, M2), (M1, M2)), + ((S, S1, S2), (S, S1, S2)), + ((XS, S, S1, S2), (XS, S, S1, S2)), + ((1, M1), (1, M1)), + ] + ) + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.storage_types = [ + "utils::kBuffer", + "utils::kTexture3D", + ] + test_suite.dtypes = ["at::kBool"] + test_suite.data_gen = "make_seq_tensor" + return test_suite + + @register_test_suite("aten.pow.Tensor_Scalar") def get_pow_tensor_scalar_inputs(): test_suite = VkTestSuite(