Commit 73c61ea

angelayietaf authored and committed
[inductor] Fix constant folder (#166655)
Fixes https://fb.workplace.com/groups/1028545332188949/permalink/1351999569843522/, where the graph produced by the constant folder references a sym node that is created later in the graph (after its use site). Graph diff: https://www.internalfb.com/intern/diffing/?paste_number=2014609054

Before:

```
%full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%select_18 : [num_users=1] = call_function[target=torch.ops.aten.select.int](args = (%full_65, 1, 0), kwargs = {})
%mul_2792 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%select_18, 0), kwargs = {})
%embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %mul_2792), kwargs = {})
```

After:

```
%full_65 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_47, 768], 1), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%full_default_1 : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_150], 0), kwargs = {dtype: torch.int64, layout: torch.strided, device: cuda:0, pin_memory: False})
%embedding_4 : [num_users=1] = call_function[target=torch.ops.aten.embedding.default](args = (%_uv__surface_embeddings_weight, %full_default_1), kwargs = {})
...
%sym_size_int_150 : [num_users=7] = call_function[target=torch.ops.aten.sym_size.int](args = (%view_193, 0), kwargs = {})
```

Note that in the "After" graph, `%full_default_1` uses `%sym_size_int_150`, which is only defined later in the graph. I couldn't figure out a small repro for this :/

Pull Request resolved: #166655
Approved by: https://github.com/eellison
1 parent: bce274d

File tree

1 file changed (+2, -1 lines)

torch/_inductor/fx_passes/joint_graph.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -227,7 +227,8 @@ def __init__(self, gm, skip_constructors=False) -> None:
         self.symint_nodes = _SymHashingDict()
         for n in self.module.graph.nodes:  # type: ignore[union-attr]
             if "val" in n.meta and isinstance(n.meta["val"], torch.SymInt):
-                self.symint_nodes[n.meta["val"]] = n
+                if n.meta["val"] not in self.symint_nodes:
+                    self.symint_nodes[n.meta["val"]] = n

         # reference from torch/_funtorch/partitioners.py:get_default_op_list
         self.view_op_packets = [
```
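The guarded insert above can be illustrated with a minimal, self-contained sketch. This is not the real `_SymHashingDict` or FX machinery: plain strings stand in for `torch.SymInt` values and graph nodes, and `collect_symint_nodes` is a hypothetical helper written only to show why keeping the *first* node that produces a given SymInt matters for graph ordering.

```python
def collect_symint_nodes(nodes, keep_first=True):
    """Map each SymInt value to one node that produces it.

    `nodes` is a list of (node_name, symint_value) pairs in graph
    (topological) order. With keep_first=False (the old behavior), a
    later node shadows an earlier one with the same SymInt value, so a
    constant-folded replacement inserted early in the graph can end up
    referencing a node that is defined after its use site. With
    keep_first=True (the fixed behavior), the earliest producer wins,
    which is always safe to reference from anywhere later in the graph.
    """
    symint_nodes = {}
    for name, val in nodes:
        if keep_first:
            # Only record the first node seen for this SymInt value.
            if val not in symint_nodes:
                symint_nodes[val] = name
        else:
            # Old behavior: last node seen silently overwrites.
            symint_nodes[val] = name
    return symint_nodes


# Two nodes produce the same symbolic size "s0", in graph order:
nodes = [("sym_size_int_47", "s0"), ("sym_size_int_150", "s0")]

old = collect_symint_nodes(nodes, keep_first=False)
new = collect_symint_nodes(nodes, keep_first=True)
print(old["s0"])  # sym_size_int_150 (later node: can break graph ordering)
print(new["s0"])  # sym_size_int_47 (earliest node: safe to reference)
```

Under this assumption, the one-line guard in the diff is exactly the `keep_first=True` branch: it makes the mapping first-writer-wins instead of last-writer-wins.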

0 commit comments