[AOTDispatch] Return mutated inputs directly when keeping mutations #120514
Conversation
Fixes #120242

The example from the issue now results in the graph

```python
def forward(self, arg0_1, arg1_1):
    sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
    copy_ = torch.ops.aten.copy_.default(arg1_1, sin); arg1_1 = sin = None
    return (copy_,)
```

and the corresponding inductor kernel eliminates the intermediate buffer completely

```python
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (5, ), (1, ))
    assert_size_stride(arg1_1, (5, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # Source Nodes: [sin], Original ATen: [aten.sin]
        stream0 = get_raw_stream(0)
        triton_poi_fused_sin_0.run(arg0_1, arg1_1, 5, grid=grid(5), stream=stream0)
        del arg0_1
    return (arg1_1, )
```
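For orientation, here is a small sketch of the kind of program this change targets (an approximation in the spirit of #120242, not necessarily the exact snippet from the issue; the function and tensor names are made up):

```python
# Hypothetical repro sketch: `out` is a graph input that is mutated in place,
# so with input mutations kept in the graph the compiled function can return
# the mutated input directly instead of materializing an intermediate buffer.
import torch

@torch.compile
def f(x, out):
    out.copy_(torch.sin(x))
    return out

x = torch.randn(5)
out = torch.empty_like(x)
f(x, out)
# On CUDA this should lower to a single fused kernel along the lines of the
# one above, writing sin(x) straight into `out`.
```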
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/120514
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 39c69f3 with merge base 953c6c3.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
```python
# necessary.
if get_node_storage(node) in output_storages and (
    get_node_storage(src) in input_storages
    or get_node_storage(src) in output_storages
```
This previous logic banned all no-op eliminations where the `src` and `node` storages are inputs or outputs, but this is only problematic if the storages weren't expected to alias. In the failing test I saw, we had `node = aten.slice(argn, ...)` where `argn` was both an input and an output of the graph because of this change. The slice op itself was not returned, so eliminating the view is not an issue.

This also generalizes further: we might have a view of a view where the second view is returned, but it is still safe to eliminate the first view op because that tensor is not returned directly.
Just to check, `node_storage != src_storage` is equivalent to the op not being a view, right? If so, could you either leave a comment or use a variable `node_is_view = node_storage == src_storage`?
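For reference, a rough sketch of what the relaxed check could look like with that suggested variable (names are taken from the hunk above; this is an illustration, not the exact code in the PR):

```python
# Illustrative only (not the exact remove-noop pass): skip elimination only
# when removing the op could change which storages the outputs alias.
for node, src in noop_candidates:  # hypothetical iteration over (op, source)
    node_storage = get_node_storage(node)
    src_storage = get_node_storage(src)
    node_is_view = node_storage == src_storage  # same storage <=> pure view
    if (
        not node_is_view
        and node_storage in output_storages
        and (src_storage in input_storages or src_storage in output_storages)
    ):
        continue  # a non-view whose removal would change output aliasing
    node.replace_all_uses_with(src)  # hypothetical elimination step
```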
```python
if copy_node == user:
# Ignore uses after the copy_ epilogue node, where the input
# has already been mutated anyway
if copy_node_loc is not None and copy_node_loc <= user_loc:
```
This change is required because the output node of the graph counts as a user, so it was preventing reinplacing on mutated inputs.
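Schematically, the user scan now behaves something like the sketch below (a simplified illustration using the names from the hunk; not the actual reinplacing pass):

```python
# Simplified illustration: a use of the mutated argument only blocks
# reinplacing if it runs *before* the copy_ epilogue; the graph's output node
# (and anything else at or after copy_) already sees the mutated value.
def blocks_reinplacing(mutated_arg, mutating_node, copy_node, node_loc):
    copy_node_loc = node_loc.get(copy_node) if copy_node is not None else None
    for user in mutated_arg.users:
        if user is mutating_node:
            continue  # the op performing the mutation itself
        user_loc = node_loc[user]
        if copy_node_loc is not None and copy_node_loc <= user_loc:
            continue  # at/after the copy_ epilogue, input already mutated
        return True  # an earlier reader still needs the old value
    return False
```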
```python
copy_node = copy_args_to_copy_nodes.get((mutated_arg, node))
if copy_node is not None:
    graph.erase_node(copy_node)
    replace_dict[copy_node] = copy_node.args[0]
```
This change is needed because when we have an in-place mutation on a tensor `x`, `make_fx` replaces all future references to `x` with the in-place node. This means the `aten.copy_` op now has a user that needs to be updated.
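A tiny runnable illustration of that `make_fx` behavior (the function `f` below is made up for demonstration; it only assumes a standard PyTorch install):

```python
# After an in-place op, make_fx traces later uses of the tensor against the
# in-place node's output, so the copy_ node acquires users that must be
# redirected (e.g. via replace_dict) rather than simply erased.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, y):
    y.copy_(x.sin())  # in-place mutation of the input y
    return y * 2      # traced as a use of the copy_ node, not of y's placeholder

gm = make_fx(f)(torch.randn(5), torch.randn(5))
print(gm.graph)
```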
ping @bdhirsh
@bdhirsh is on vacation for two weeks.
Let's just wait for @bdhirsh to be back then, as this issue is not blocking anything.
reinplace changes look good to me
@pytorchbot merge -r
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Rebase failed due to a failed command.
Raised by https://github.com/pytorch/pytorch/actions/runs/8203733237
…mutations" Fixes #120242 The example from the issue now results in the graph ```python def forward(self, arg0_1, arg1_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None copy_ = torch.ops.aten.copy_.default(arg1_1, sin); arg1_1 = sin = None return (copy_,) ``` and the corresponding inductor kernel eliminates the intermediate buffer completely ```python def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (5, ), (1, )) assert_size_stride(arg1_1, (5, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # Source Nodes: [sin], Original ATen: [aten.sin] stream0 = get_raw_stream(0) triton_poi_fused_sin_0.run(arg0_1, arg1_1, 5, grid=grid(5), stream=stream0) del arg0_1 return (arg1_1, ) ``` cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
Fixes #120242 The example from the issue now results in the graph ```python def forward(self, arg0_1, arg1_1): sin = torch.ops.aten.sin.default(arg0_1); arg0_1 = None copy_ = torch.ops.aten.copy_.default(arg1_1, sin); arg1_1 = sin = None return (copy_,) ``` and the corresponding inductor kernel eliminates the intermediate buffer completely ```python def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (5, ), (1, )) assert_size_stride(arg1_1, (5, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # Source Nodes: [sin], Original ATen: [aten.sin] stream0 = get_raw_stream(0) triton_poi_fused_sin_0.run(arg0_1, arg1_1, 5, grid=grid(5), stream=stream0) del arg0_1 return (arg1_1, ) ``` ghstack-source-id: eee04a3 Pull Request resolved: #120514
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
Fixes #120242
The example from the issue now results in the graph shown above, and the corresponding inductor kernel eliminates the intermediate buffer completely.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler @amjames @desertfire @chauhang