Skip to content

Commit

Permalink
[fx][split] Copy node metadata for placeholders (#107981)
Browse files Browse the repository at this point in the history
- Follow-up to #107248 which copies metadata for placeholder nodes in the top-level FX graph
- Currently, top-level placeholders do not have their metadata copied over, causing loss of `TensorMetadata` in some `torch.compile` backends

Fixes pytorch/TensorRT#2258
Pull Request resolved: #107981
Approved by: https://github.com/angelayi
  • Loading branch information
gs-olive authored and pytorchmergebot committed Sep 7, 2023
1 parent 56b8481 commit 6a44881
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions torch/fx/passes/split_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ def flatten(x: torch.fx.node.Argument) -> NodeList:
# Placeholders in the original graph get copied to main graph.
if node.op == "placeholder":
main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
main_remapping[node].meta = copy.copy(node.meta)
continue

# Get_attr nodes are ignored because we are not tagging them.
Expand Down

0 comments on commit 6a44881

Please sign in to comment.