
Conversation


@xmfan xmfan commented Oct 15, 2025

Stack from ghstack (oldest at bottom):

This PR fixes two issues with local_mapping token-choice MoE. Splits from the forward token dispatch produce tensors with unbacked shapes; these unbacked shapes are fully contained within the a2as and should not leak outside of the joint graph. The HOP body forward and backward are expected to coerce back to static shapes (due to being added to the shared experts' output) before returning.

```python
routed_output: "bf16[u0 + u1 + u10 + u11 + u12 + u13 + u14 + u15 + u16 + u17 + u18 + u19 + u2 + u20 + u21 + u22 + u23 + u24 + u25 + u26 + u27 + u28 + u29 + u3 + u30 + u31 + u32 + u33 + u34 + u35 + u36 + u37 + u38 + u39 + u4 + u40 + u41 + u42 + u43 + u44 + u45 + u46 + u47 + u48 + u49 + u5 + u50 + u51 + u52 + u53 + u54 + u55 + u56 + u57 + u58 + u59 + u6 + u60 + u61 + u62 + u63 + u7 + u8 + u9, 2048]" = torch.ops.higher_order.autograd_function_apply(fwd_body_1, bwd_body_1, out_1, item, item_1, item_2, item_3, item_4, item_5, item_6, item_7, item_8, item_9, item_10, item_11, item_12, item_13, item_14, item_15, item_16, item_17, item_18, item_19, item_20, item_21, item_22, item_23, item_24, item_25, item_26, item_27, item_28, item_29, item_30, item_31, item_32, item_33, item_34, item_35, item_36, item_37, item_38, item_39, item_40, item_41, item_42, item_43, item_44, item_45, item_46, item_47, item_48, item_49, item_50, item_51, item_52, item_53, item_54, item_55, item_56, item_57, item_58, item_59, item_60, item_61, item_62, item_63, item_64, item_65, item_66, item_67, item_68, item_69, item_70, item_71, item_72, item_73, item_74, item_75, item_76, item_77, item_78, item_79, item_80, item_81, item_82, item_83, item_84, item_85, item_86, item_87, item_88, item_89, item_90, item_91, item_92, item_93, item_94, item_95, item_96, item_97, item_98, item_99, item_100, item_101, item_102, item_103, item_104, item_105, item_106, item_107, item_108, item_109, item_110, item_111, item_112, item_113, item_114, item_115, item_116, item_117, item_118, item_119, item_120, item_121, item_122, item_123, item_124, item_125, item_126, item_127, args_tensor_mask = [True, False, False, False], non_differentiable_idx = []);  fwd_body_1 = bwd_body_1 = out_1 = item = item_1 = item_2 = item_3 = item_4 = item_5 = item_6 = item_7 = item_8 = item_9 = item_10 = item_11 = item_12 = item_13 = item_14 = item_15 = item_16 = item_17 = item_18 = item_19 = item_20 = item_21 = item_22 = item_23 = item_24 = item_25 = item_26 = item_27 = item_28 = item_29 = item_30 = item_31 = item_32 = item_33 = item_34 = item_35 = item_36 = item_37 = item_38 = item_39 = item_40 = item_41 = item_42 = item_43 = item_44 = item_45 = item_46 = item_47 = item_48 = item_49 = item_50 = item_51 = item_52 = item_53 = item_54 = item_55 = item_56 = item_57 = item_58 = item_59 = item_60 = item_61 = item_62 = item_63 = item_64 = item_65 = item_66 = item_67 = item_68 = item_69 = item_70 = item_71 = item_72 = item_73 = item_74 = item_75 = item_76 = item_77 = item_78 = item_79 = item_80 = item_81 = item_82 = item_83 = item_84 = item_85 = item_86 = item_87 = item_88 = item_89 = item_90 = item_91 = item_92 = item_93 = item_94 = item_95 = item_96 = item_97 = item_98 = item_99 = item_100 = item_101 = item_102 = item_103 = item_104 = item_105 = item_106 = item_107 = item_108 = item_109 = item_110 = item_111 = item_112 = item_113 = item_114 = item_115 = item_116 = item_117 = item_118 = item_119 = item_120 = item_121 = item_122 = item_123 = item_124 = item_125 = item_126 = item_127 = None
        
# File: /home/xmfan/core/a/autoparallel/examples/example_ds3_local_map.py:777 in local_mapped_region, code: torch._check(routed_output.shape[0] == shape[0] * shape[1])
size_3 = routed_output.size()
getitem_139 = size_3[1];  size_3 = getitem_139 = None
        
# File: /home/xmfan/core/a/autoparallel/examples/example_ds3_local_map.py:779 in local_mapped_region, code: routed_output = routed_output.view(shape)
routed_output_1: "bf16[4, 6144, 2048]" = routed_output.view((4, 6144, 2048));  routed_output = None

# File: /home/xmfan/core/a/autoparallel/examples/example_ds3_local_map.py:781 in local_mapped_region, code: out = out.scatter_add(dim=1, index=token_indices_experts_sorted, src=routed_output)
out_3: "bf16[4, 1024, 2048]" = out_2.scatter_add(dim = 1, index = token_indices_experts_sorted_2, src = routed_output_1);  out_2 = token_indices_experts_sorted_2 = routed_output_1 = None
```

## 1. Unbacked symints contained within the HOP body

Based on 9b2974e and 36030e0.

We disable proxy mode so that unbacked symints contained within the HOP subgraph aren't proxied:
```python

[rank0]: RuntimeError: u576 + u577 + u578 + u579 + u580 + u581 + u582 + u583 + u584 + u585 + u586 + u587 + u588 + u589 + u590 + u591 + u592 + u593 + u594 + u595 + u596 + u597 + u598 + u599 + u600 + u601 + u602 + u603 + u604 + u605 + u606 + u607 + u608 + u609 + u610 + u611 + u612 + u613 + u614 + u615 + u616 + u617 + u618 + u619 + u620 + u621 + u622 + u623 + u624 + u625 + u626 + u627 + u628 + u629 + u630 + u631 + u632 + u633 + u634 + u635 + u636 + u637 + u638 + u639 + 1 (140667108386064)is not tracked with proxy for <torch.fx.experimental.proxy_tensor.PythonKeyTracer object at 0x7fef9d44f950>

```

And we ensure that no unbacked symints leak outside of the region: the pending unbacked symbols created inside the HOP subgraph are cleared, since otherwise `compute_unbacked_bindings` raises `PendingUnbackedSymbolNotFound` for symbols that never appear in the returned outputs.
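A minimal sketch of that guard pattern (the helper name, its arguments, and the partitioner call signature are assumptions for illustration, not the actual PR diff):

```python
# Sketch only: partition the local_map joint graph with proxy tracing disabled
# and with freshly created unbacked symbols ignored, so that u0, u1, ... minted
# inside the HOP body are neither proxied nor reported as pending bindings.
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing

def partition_local_map_joint(joint_gm, joint_inputs, num_fw_outputs, fake_mode, partition_fn):
    with disable_proxy_modes_tracing(), \
            fake_mode.shape_env.ignore_fresh_unbacked_symbols():
        # partition_fn is e.g. min_cut_rematerialization_partition; the exact
        # signature is assumed here.
        new_fw_gm, new_bw_gm = partition_fn(
            joint_gm, joint_inputs, num_fwd_outputs=num_fw_outputs
        )
    return new_fw_gm, new_bw_gm
```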

## 2. Saved symint activations

local_map uses the partitioned backward, so it needs to follow the partitioner's desired ordering. This is the same order the AOTAutograd runtime wrapper uses in `_backward_prologue_functional`, where symints are passed first (https://github.com/pytorch/pytorch/blob/d2c82bafb7086a1dd109a0a6407ca7fed27337f4/torch/_functorch/_aot_autograd/runtime_wrappers.py#L1702-L1704):

```python
all_args = [
    *ctx_symints,
    *ctx_saved_tensors,
```
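For illustration, a minimal sketch of that convention as a plain function (names are illustrative, not the actual wrapper code): symint activations come first, then tensor activations, then the incoming gradients.

```python
# Sketch only: assemble the partitioned backward's inputs symints-first,
# mirroring the ordering used by _backward_prologue_functional.
def assemble_backward_inputs(saved_symints, saved_tensors, grad_outputs):
    return [*saved_symints, *saved_tensors, *grad_outputs]
```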


pytorch-bot bot commented Oct 15, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165551

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit a6a8294 with merge base 39a70ce:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

xmfan added a commit that referenced this pull request Oct 15, 2025
@xmfan xmfan added the "topic: not user facing" label (topic category) on Oct 15, 2025
xmfan added a commit that referenced this pull request Oct 16, 2025
xmfan added a commit that referenced this pull request Oct 16, 2025
@xmfan xmfan changed the title from "[hop] for partitioned hops, pass symints before tensor activations" to "[hop] local_map, for partitioned hops pass symints before tensor activations" on Oct 16, 2025
@xmfan xmfan changed the title from "[hop] local_map, for partitioned hops pass symints before tensor activations" to "[hop] local_map, pass symints before tensor activations for the partitioned backwards" on Oct 16, 2025
```python
ctx.pos = list(
    reversed(ctx.pos)
)  # make saved_tensors_and_symints return symints first
```

Member Author

ctx.pos is a list of positional indexes based on the forward:

```python
for arg in args:
    idx = 0 if isinstance(arg, torch.Tensor) else 1
    partitioned_args[idx].append(arg)
    pos.append(idx)
```

ctx.pos[i] is 0 for tensors and 1 for others. We aren't dealing with more than just tensors and symints.
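As a self-contained sketch of that bookkeeping (the real helpers live in the HOP utilities; the names below are illustrative): pos records which bucket each original argument went into, so the backward can rebuild the original order from the two buckets.

```python
import torch

def split_tensors_and_others(args):
    # Same bucketing as the forward-side snippet quoted above.
    partitioned_args = ([], [])
    pos = []
    for arg in args:
        idx = 0 if isinstance(arg, torch.Tensor) else 1
        partitioned_args[idx].append(arg)
        pos.append(idx)
    return partitioned_args, pos

def reinterleave(tensors, others, pos):
    # Rebuild the original argument order from the two buckets using pos.
    iters = [iter(tensors), iter(others)]
    return [next(iters[idx]) for idx in pos]
```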

Contributor

can you help me understand why we need to do this specifically for local_map, in a way that isn't handled generically by the helpers in that linked file? (saved_tensors_and_symints and friends).

I guess specifically it's not clear to me why we need to reverse the order of activations here vs at the time that we generated them in the forward output graph

Member Author

@xmfan xmfan Oct 16, 2025

All the other HOPs that use saved_tensors_and_symints run the entire joint graph for their backward, so their backward input signature matches the joint they traced in the forward.

But local_map uses the partitioned backward and needs a different ordering; it is the same order the AOTAutograd runtime wrapper uses in `_backward_prologue_functional`:

```python
all_args = [
    *ctx_symints,
    *ctx_saved_tensors,
```

Contributor

It's kinda strange to me that the output order convention we have is tensors-then-symints for the HOPs that we desugar, vs symints-then-tensors for AOTAutograd. But I guess this is pre-existing?

@xmfan xmfan requested review from bdhirsh and ydwu4 October 16, 2025 14:57
@xmfan xmfan marked this pull request as ready for review October 16, 2025 14:57
@xmfan xmfan requested a review from zou3519 as a code owner October 16, 2025 14:57
…r the partitioned backwards"

[ghstack-poisoned]
xmfan added a commit that referenced this pull request Oct 26, 2025
ghstack-source-id: e793729
Pull Request resolved: #165551
@xmfan xmfan changed the title from "[hop] local_map, pass symints before tensor activations for the partitioned backwards" to "[hop] local_map MoE: fix symint activations" on Oct 26, 2025
[ghstack-poisoned]
xmfan added a commit that referenced this pull request Oct 27, 2025
ghstack-source-id: b9be9ac
Pull Request resolved: #165551
@xmfan xmfan requested a review from bobrenjc93 October 27, 2025 01:03
@xmfan xmfan changed the title from "[hop] local_map MoE: fix symint activations" to "[hop] local_map MoE: fix unbacked symints during tracing and symint activations" on Oct 27, 2025
@xmfan xmfan changed the title from "[hop] local_map MoE: fix unbacked symints during tracing and symint activations" to "[hop] local_map MoE: fix unbacked symints during tracing and symint activations order in the wrapper" on Oct 27, 2025
```python
num_activations = (
    len(new_fw_gm.graph.find_nodes(op="output")[0].args[0]) - num_fw_outputs
)
# tensors first, then symints
```

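For context, a hedged sketch of the quantity being computed here (variable names mirror the excerpt; the surrounding code is assumed): the partitioned forward graph returns its real outputs first, followed by everything it saves for the backward, and num_activations counts those saved values.

```python
# Sketch only, not the PR diff.
output_node = new_fw_gm.graph.find_nodes(op="output")[0]
all_fw_outputs = output_node.args[0]           # fw outputs + saved activations
fw_outs = all_fw_outputs[:num_fw_outputs]
activations = all_fw_outputs[num_fw_outputs:]  # per the comment: tensors first, then symints
num_activations = len(all_fw_outputs) - num_fw_outputs
```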
Contributor

I'm not sure I understand this comment. Did you mean to keep it here?


xmfan commented Oct 27, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the "ciflow/trunk" label (Trigger trunk jobs on your pull request) on Oct 27, 2025
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable)

Details for Dev Infra team: raised by workflow job.


xmfan commented Oct 27, 2025

@pytorchbot merge -i

@pytorchmergebot

Merge started

Your change will be merged while ignoring the following 1 checks: trunk / macos-py3-arm64 / test (default, 1, 3, macos-m1-stable)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

…nd symint activations order in the wrapper"


[ghstack-poisoned]
xmfan added a commit that referenced this pull request Oct 27, 2025
ghstack-source-id: 6b3cca8
Pull Request resolved: #165551
@pytorchmergebot

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team: raised by workflow job.


```python
# propagate local_map args to the call_function node
out_proxy.node.meta["local_map_kwargs"] = local_map_kwargs
```

Contributor

not related to the PR, but out of curiosity - what's the plan for properly installing local map regions as subgraphs?

```python
# TODO: get rid of this when we can install as a subgraph
```

(context: I tried running the new test locally and noticed I couldn't easily see the inner local map region in the GraphModule)

Member Author

We need `torch IR -> predispatch -> post-dispatch` to match `torch IR -> post-dispatch`, but that's not always the case. I think I had issues with custom autograd functions or something else.

Contributor

🤔

```python
    num_fwd_outputs=num_fw_outputs,
    static_lifetime_input_indices=[],
)
with disable_proxy_modes_tracing():
```

Contributor

do you mind adding a comment explaining why we need to disable proxy tracing when we run the partitioner here? (it's not actually clear to me why this is necessary)

Member Author

@xmfan xmfan Oct 28, 2025

I don't fully understand it, but it boils down to this code path:

```python
if get_proxy_mode():
    return to_node(
        self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
    )
```

which assumes that under a proxy mode, the symints must have a corresponding proxy. But this doesn't seem to be true for some tensor nodes with compositional shapes: e.g. u0 + u1 has no proxy, even though u0 and u1 each have proxies.
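A toy illustration of that assumption (plain Python, not PyTorch internals): proxies are recorded per symbol when the symbol is created, so a derived expression like u0 + u1 has nothing to look up.

```python
# Toy model of the "not tracked with proxy" failure.
proxies = {"u0": "proxy_u0", "u1": "proxy_u1"}  # leaf unbacked symints get proxies

def lookup_proxy(expr):
    if expr not in proxies:
        raise RuntimeError(f"{expr} is not tracked with proxy")
    return proxies[expr]

lookup_proxy("u0")        # ok
lookup_proxy("u0 + u1")   # raises, mirroring the error in the PR description
```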

```python
input_split_sizes = output_split_sizes

tensor = torch.ops._c10d_functional.all_to_all_single(
    self, output_split_sizes, input_split_sizes, group_name
```

Contributor

what's the status of "runtime asserts that are generated inside of a subgraph"? (maybe @bobrenjc93 knows). In this test, does the generated aot_eager code and/or inductor code end up with the proper runtime assert in it that the sum of the split sizes is static?

Contributor

It should after #165893. Maybe @xmfan can generate a tlparse to verify?

Member Author

@xmfan xmfan Oct 28, 2025

They aren't handled properly: they're still pending generation after graph capture and trip an assert, which is why we had to use ignore_fresh_unbacked_symbols.

Member Author

@bobrenjc93 that PR looks like it adds support only for Dynamo; I'm guessing this is only for symints that leak outside of the HOP? In the local_map case, the symints only appear during the HOP joint trace and partition.


xmfan commented Oct 28, 2025

@pytorchbot merge

@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/trunk, Merged, topic: not user facing
