From 730dc8b99251a6c495e6525d41282a5afdb962ba Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 18 Nov 2025 00:31:41 -0800
Subject: [PATCH] Tag get_attr in AOTI partitioner so that we can tag away
 `submodule`s from `torch.cond`.

Let's say we have some eager code like this:

```py
# Tensor predicate: True if any element is non-zero.
# The result is a 0-dim bool tensor suitable for torch.cond.
cache_is_initialized = (cached_keys != 0).any()

# Use torch.cond to select a branch in a traceable way.
# All operands must be (nested) tensors or simple Python values.
key_states, value_states = torch.cond(
    cache_is_initialized,
    use_cached_kv,
    recompute_kv,
    operands=(cached_keys, cached_values, key_value_states),
)
```

Basically, we check whether the KV cache is still all zeros: if it is, we compute the KV projections; otherwise we read the KV states back from the cache.
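For context, a minimal sketch of what the two branch functions might look like (not part of this PR; `hidden_size` and the `k_proj` / `v_proj` layers are placeholders, loosely modeled on the `encoder_attn_k_proj_weight` / `encoder_attn_v_proj_weight` parameters that appear in the exported graph below). Both branches must accept the same operands and return tensors with matching shapes and dtypes:

```py
import torch
from torch import nn

hidden_size = 1024  # hypothetical; use whatever the decoder actually uses

# Hypothetical projection layers; in the real model these live on the
# decoder layer's cross-attention module.
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size)


def use_cached_kv(cached_keys, cached_values, key_value_states):
    # Cache already holds the cross-attention KV states: return them as-is.
    return cached_keys.clone(), cached_values.clone()


def recompute_kv(cached_keys, cached_values, key_value_states):
    # Cache is still empty: project the encoder hidden states into K/V.
    return k_proj(key_value_states), v_proj(key_value_states)
```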
After torch.export'ing, the graph around torch.cond becomes:

```
%any_1 : [num_users=1] = call_function[target=torch.ops.aten.any.default](args = (%ne,), kwargs = {})
%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]
%false_graph_0 : [num_users=1] = get_attr[target=false_graph_0]
%cond : [num_users=2] = call_function[target=torch.ops.higher_order.cond](args = (%any_1, %true_graph_0, %false_graph_0, (%b_cache_cross_attention_cache_layers_0_keys, %b_cache_cross_attention_cache_layers_0_values, %p_decoder_layers_0_encoder_attn_k_proj_weight, %p_decoder_layers_0_encoder_attn_v_proj_bias, %p_decoder_layers_0_encoder_attn_v_proj_weight, %encoder_hidden_states)), kwargs = {})
```

After tagging and delegation it becomes:

```
graph():
    %decoder_input_ids : [num_users=1] = placeholder[target=decoder_input_ids]
    %encoder_hidden_states : [num_users=1] = placeholder[target=encoder_hidden_states]
    %cache_position : [num_users=1] = placeholder[target=cache_position]
    %submodule_0 : [num_users=1] = get_attr[target=submodule_0]
    %submodule_1 : [num_users=1] = get_attr[target=submodule_1]
    %submodule_2 : [num_users=1] = get_attr[target=submodule_2]
    %submodule_3 : [num_users=1] = get_attr[target=submodule_3]
    %submodule_4 : [num_users=1] = get_attr[target=submodule_4]
    %submodule_5 : [num_users=1] = get_attr[target=submodule_5]
    %submodule_6 : [num_users=1] = get_attr[target=submodule_6]
    %submodule_7 : [num_users=1] = get_attr[target=submodule_7]
    %fused_tag0 : [num_users=17] = call_module[target=fused_tag0](args = (%decoder_input_ids, %cache_position, %submodule_0, %submodule_1, %encoder_hidden_states, %submodule_2, %submodule_3, %submodule_4, %submodule_5, %submodule_6, %submodule_7), kwargs = {})
    %getitem_8 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 0), kwargs = {})
    %getitem_9 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 1), kwargs = {})
    %getitem_10 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 2), kwargs = {})
    %getitem_11 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 3), kwargs = {})
    %getitem_12 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 4), kwargs = {})
    %getitem_13 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 5), kwargs = {})
    %getitem_14 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 6), kwargs = {})
    %getitem_15 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 7), kwargs = {})
    %getitem_16 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 8), kwargs = {})
    %getitem_17 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 9), kwargs = {})
    %getitem_18 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 10), kwargs = {})
    %getitem_19 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 11), kwargs = {})
    %getitem_20 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 12), kwargs = {})
    %getitem_21 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 13), kwargs = {})
    %getitem_22 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 14), kwargs = {})
    %getitem_23 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 15), kwargs = {})
    %getitem_24 : [num_users=1] = call_function[target=operator.getitem](args = (%fused_tag0, 16), kwargs = {})
    return (getitem_16, getitem_17, getitem_8, getitem_9, getitem_18, getitem_19, getitem_10, getitem_11, getitem_20, getitem_21, getitem_12, getitem_13, getitem_22, getitem_23, getitem_14, getitem_15, getitem_24)
```

But those submodules can actually be delegated away to AOTI. This PR makes sure we tag them properly.
---
 backends/aoti/aoti_partitioner.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/backends/aoti/aoti_partitioner.py b/backends/aoti/aoti_partitioner.py
index 499bc57b735..aa56d3507e9 100644
--- a/backends/aoti/aoti_partitioner.py
+++ b/backends/aoti/aoti_partitioner.py
@@ -52,10 +52,24 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         partition_tags: Dict[str, DelegationSpec] = {}
         tag = "tag0"
 
+        # Tag torch.cond and other control flow operations
+        def is_control_flow(node: torch.fx.Node) -> bool:
+            return node.op == "call_function" and node.target in [
+                torch.ops.higher_order.cond,
+                torch.ops.higher_order.map_impl,
+                torch.ops.higher_order.while_loop,
+            ]
+
         for node in exported_program.graph.nodes:
-            if node.op != "call_function":
-                continue
-            node.meta["delegation_tag"] = tag
+            if node.op == "call_function":
+                node.meta["delegation_tag"] = tag
+            # Tag get_attr nodes that are used by control flow operations
+            elif node.op == "get_attr":
+                # Check if any user is a control flow operation
+                for user in node.users:
+                    if is_control_flow(user):
+                        node.meta["delegation_tag"] = tag
+                        break
 
         partition_tags[tag] = self.delegation_spec
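
As an aside (not part of the patch), the tagging rule above can be exercised in isolation with a small self-contained sketch that uses only `torch.export`: the `Gate` module, branch functions, and `"tag0"` string are made up for illustration, but the control-flow op list and the tagging loop mirror the change in this PR. Running it shows the `get_attr` nodes for `true_graph_0` / `false_graph_0` picking up the delegation tag.

```py
import torch
from torch.export import export


class Gate(torch.nn.Module):
    def forward(self, x):
        def use_cache(x):
            return x + 1

        def recompute(x):
            return x - 1

        # Tensor predicate, analogous to the KV-cache check in the commit message.
        return torch.cond((x != 0).any(), use_cache, recompute, (x,))


ep = export(Gate(), (torch.randn(3),))

CONTROL_FLOW_OPS = [
    torch.ops.higher_order.cond,
    torch.ops.higher_order.map_impl,
    torch.ops.higher_order.while_loop,
]

tag = "tag0"
for node in ep.graph.nodes:
    if node.op == "call_function":
        node.meta["delegation_tag"] = tag
    elif node.op == "get_attr" and any(
        user.op == "call_function" and user.target in CONTROL_FLOW_OPS
        for user in node.users
    ):
        # true_graph_0 / false_graph_0 feed higher_order.cond, so they get tagged.
        node.meta["delegation_tag"] = tag

for node in ep.graph.nodes:
    print(f"{node.op:15} {node.name:20} {node.meta.get('delegation_tag')}")
```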