From 90f43273f4505d123abe16c586f4fe0ae43d4e92 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 8 Apr 2026 08:51:53 -0700 Subject: [PATCH 1/2] [ET-VK] Fix force_fp16 texture bias being silently rejected for CONTIGUOUS_ANY ops MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Pull Request resolved: https://github.com/pytorch/executorch/pull/18770 The `force_fp16` path in `TagMemoryMetaPass` applies `ANY_TEXTURE` to bias ops toward texture storage. However, `try_constrain_with_arg_repset` has a packed-dim compatibility check that requires ALL of the source repset's PDIs to exist in the output repset. `ANY_TEXTURE` has 3 texture layouts (WP, HP, CP) but `CONTIGUOUS_ANY` outputs only support WP, so the check fails and the texture bias is silently dropped. Without the bias, buffer storage cascades from ops that must use buffer (e.g. embedding with vocab exceeding texture limits) into downstream ops that could use texture, causing unnecessary buffer↔texture transitions. Fix: check PDI compatibility against the intersection of arg and source repsets (what would actually be applied) rather than the raw source. The intersection of `ANY_TEXTURE ∩ CONTIGUOUS_ANY` = `WIDTH_PACKED_TEXTURE`, which IS compatible with the output. Authored by Claude. ghstack-source-id: 364280901 @exported-using-ghexport Differential Revision: [D100004702](https://our.internmc.facebook.com/intern/diff/D100004702/) --- backends/vulkan/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index c17f9332e0c..f93fec167eb 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -1511,14 +1511,19 @@ def try_constrain_with_arg_repset( if not arg_current_repset.any_in_common(source_repset): return False + # Compute the narrowed repset (intersection of current arg and source). + narrowed = arg_current_repset.make_intersect(source_repset) + if self.sync_primary_io_repr: - if not self.get_out_repset(0).has_compatible_packed_dim_info_set( - source_repset - ): + # Check that the narrowed result is compatible with the output. + # Using the intersection rather than the raw source_repset avoids + # rejecting valid constraints where the source has extra layouts + # (e.g. ANY_TEXTURE includes HP/CP) that don't exist in the output + # but also don't appear in the intersection. + if not self.get_out_repset(0).has_compatible_packed_dim_info_set(narrowed): return False # If this point is reached, then it is possible to constrain - narrowed = arg_current_repset.make_intersect(source_repset) self.args_repset_list[arg_i] = narrowed # Propagate to other synced args via packed-dim compatibility From 82044ebd03fda44b86b815a6707d5731a83cb154 Mon Sep 17 00:00:00 2001 From: ssjia Date: Wed, 8 Apr 2026 08:51:55 -0700 Subject: [PATCH 2/2] [ET-VK] Deduplicate transition clone nodes in TagMemoryMetaPass Pull Request resolved: https://github.com/pytorch/executorch/pull/18771 When the same tensor is consumed by multiple ops that need a different storage representation, the pass previously inserted a separate clone transition for each consumer. Now it caches transition clones keyed by (source_node, target_storage_type, target_layout) and reuses existing clones when the same transition is needed again. For Qwen3 0.6B (8da4w fp16), the embedding output (BUFFER due to vocab_size exceeding texture limits) feeds both rms_norm and add which need TEXTURE. Previously 2 clones were inserted; now 1 clone is shared. Authored by Claude. ghstack-source-id: 364280900 @exported-using-ghexport Differential Revision: [D100004700](https://our.internmc.facebook.com/intern/diff/D100004700/) --- .../vulkan/_passes/tag_memory_meta_pass.py | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 0f8263fa572..6f5cb10f1b2 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -28,11 +28,13 @@ def insert_transition_node( node: torch.fx.Node, arg: torch.fx.Node, arg_node_repr: utils.TensorRepr, -) -> None: +) -> torch.fx.Node: """ Insert a clone node to transition the tensor associated with `arg` to a tensor with the requested representation `arg_node_repr`, and use the cloned node as an argument to `node` instead of `arg`. + + Returns the newly created clone node. """ with graph_module.graph.inserting_before(node): clone_node = graph_module.graph.create_node( @@ -45,6 +47,7 @@ def insert_transition_node( clone_node.meta["spec"].const = False utils.set_node_repr(clone_node, arg_node_repr) arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) + return clone_node def set_arg_node_repr_or_transition( @@ -53,6 +56,7 @@ def set_arg_node_repr_or_transition( arg_i: int, arg_node_repr: utils.TensorRepr, dirty: bool, + transition_cache: dict | None = None, ) -> bool: """ Does one of following: @@ -60,7 +64,8 @@ def set_arg_node_repr_or_transition( does not currently have a `node_repr` 2. No-op if the current `node_repr` is already the same as the requested represetnation. 3. Insert a transition node to create a copy of the argument with the desired `node_repr` - if the current `node_repr` is different than what is needed. + if the current `node_repr` is different than what is needed. If a transition clone + already exists for the same (source, target_repr) pair, reuse it. """ arg_node = op_node.args[arg_i] @@ -78,15 +83,33 @@ def single_node_impl(node: torch.fx.Node) -> bool: if cur_node_repr == arg_node_repr: return False + assert utils.is_single_tensor_node(node) + + # Check if a transition clone already exists for this (source, target_repr). + cache_key = ( + node, + arg_node_repr.storage_type, + arg_node_repr.memory_layout, + ) + if transition_cache is not None and cache_key in transition_cache: + cached_clone = transition_cache[cache_key] + node.replace_all_uses_with(cached_clone, lambda x, y=op_node: x == y) + if not dirty: + logger.info( + f"[Vulkan Delegate] Reusing transition for {op_node.format_node()}:" + ) + logger.info(f" arg {arg_i} ({node}): reusing {cached_clone}") + return True + if not dirty: logger.info( f"[Vulkan Delegate] Inserting transition(s) for {op_node.format_node()}:" ) - # Existing node representation is different; insert a transition node - # Currently, the transition node insertion logic can only handle single tensor nodes - assert utils.is_single_tensor_node(node) - insert_transition_node(graph_module, op_node, node, arg_node_repr) + clone_node = insert_transition_node(graph_module, op_node, node, arg_node_repr) + + if transition_cache is not None: + transition_cache[cache_key] = clone_node logger.info(f" arg {arg_i} ({node}): ({cur_node_repr}) -> ({arg_node_repr})") @@ -407,7 +430,10 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: self.constrain_op_out_repset(op_repsets) def set_op_node_tensor_reprs( - self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node + self, + graph_module: torch.fx.GraphModule, + op_node: torch.fx.Node, + transition_cache: dict | None = None, ) -> None: """ For an operator representated by `op_node`, get the OpRepSets associated with @@ -458,7 +484,12 @@ def set_op_node_tensor_reprs( if isinstance(arg_node, torch.fx.Node): transitions_inserted = ( set_arg_node_repr_or_transition( - graph_module, op_node, i, arg_node_repr, transitions_inserted + graph_module, + op_node, + i, + arg_node_repr, + transitions_inserted, + transition_cache, ) or transitions_inserted ) @@ -473,12 +504,14 @@ def set_op_node_tensor_reprs( i, arg_node_repr, transitions_inserted, + transition_cache, ) or transitions_inserted ) def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + transition_cache: dict = {} for node in graph_module.graph.nodes: - self.set_op_node_tensor_reprs(graph_module, node) + self.set_op_node_tensor_reprs(graph_module, node, transition_cache) return PassResult(graph_module, True)