[inductor] Implement clone removal for user defined triton kernel via reinplace_scatters #111627
… reinplace_scatters [ghstack-poisoned]
Dr. CI: ✅ No failures as of commit e5236b0 with merge base bf01a7b.
… reinplace_scatters ghstack-source-id: 1d009e4c5e8ae86ea6407c737e69476ab5cd14c1 Pull Request resolved: #111627
    dst.target == operator.getitem
    and dst.args[0].kwargs["kwargs"][dst.args[1]] == node.args[0]
):
    dst = dst.args[0]
I'm a little iffy about this since `dst` becomes `triton_kernel_wrapper_functional_proxy` here, but with the above check that should be safe.
This seems a bit confusing to me, since `getitem` is the actual node that we're copying into the input, not `triton_kernel_wrapper_functional_proxy`. But I guess it's not wrong 🤔
@@ -627,44 +630,69 @@ def reinplace_scatters(graph):
    for node in reversed(graph.nodes):
        storage_to_nodes[get_node_storage(node)].append(node)
        if node.target == aten.copy_.default:
            copy_args_to_copy_nodes[(node.args[0], node.args[1])] = node
            src = node.args[0]
            dst = node.args[1]
Hmm, isn't this flipped? I think the user code `dst.copy_(src)` will show up in the graph as:
torch.ops.aten.copy_.default(dst, src)
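To make the argument order concrete, here is a minimal sketch using a toy stand-in for a graph node. The `Node` dataclass and the `unpack_copy` helper are hypothetical names for illustration, not inductor code. It assumes, as the comment above says, that `dst.copy_(src)` is recorded with the destination as `args[0]` and the source as `args[1]`.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Node:
    """Toy stand-in for a graph node (hypothetical; not torch.fx.Node)."""
    name: str
    target: str = "placeholder"
    args: Tuple["Node", ...] = ()


def unpack_copy(node: Node) -> Tuple[Node, Node]:
    """Return (dst, src) for a copy_ node, assuming args are (dst, src)."""
    assert node.target == "aten.copy_.default"
    dst, src = node.args[0], node.args[1]
    return dst, src


x = Node("x")                        # graph input that the kernel mutates
new_x = Node("getitem", "getitem")   # value produced by the kernel wrapper
copy = Node("copy_", "aten.copy_.default", (x, new_x))

dst, src = unpack_copy(copy)
print(dst.name, src.name)  # x getitem
```

Under that assumption, the graph input is the destination and the `getitem` result is the source, which is the opposite of the naming in the hunk above.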
    return False

# Check for any uses other than current node and copy_ epilogue
if len(mutated_arg.users) > 2:
I think this check will miss some cases. In particular, `mutated_arg` can have more than 2 users, but it can still be ok to reinplace. Example:
def f(x, out):
    # x will now have **three** users: `add`, `triton_kernel`, and `copy_`
    tmp = torch.add(x, 1)
    triton_kernel[grid](inp=x, out=out)
    return out, tmp
Hmm, I think what we probably want is: if we look at users of `mutated_arg` later in the graph than the triton kernel, `copy_()` should be the only user (if this is the case, then it's safe to reinplace).
I'm not sure if there's a builtin way in FX to check "number of users after a given node", but there's a util here that does something similar: https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/reinplace.py#L176
cc @Chillee (lmk if this sounds right to you)
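The "users after a given node" idea can be sketched on a toy graph. Everything here (`Node`, `add_node`, `users_after`, `only_copy_follows`) is a hypothetical illustration in plain Python, not the FX API or the util linked above. It only assumes that graph order is list order and that each node records its users.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # identity-based equality, as with real graph nodes
class Node:
    """Toy FX-like node (hypothetical, not torch.fx.Node)."""
    name: str
    args: List["Node"] = field(default_factory=list, repr=False)
    users: List["Node"] = field(default_factory=list, repr=False)


def add_node(graph: List[Node], name: str, *args: Node) -> Node:
    """Append a node to the graph and register it as a user of its args."""
    node = Node(name, list(args))
    for arg in args:
        arg.users.append(node)
    graph.append(node)
    return node


def users_after(arg: Node, anchor: Node, graph: List[Node]) -> List[Node]:
    """Users of `arg` that come later in graph order than `anchor`."""
    order = {id(n): i for i, n in enumerate(graph)}
    return [u for u in arg.users if order[id(u)] > order[id(anchor)]]


def only_copy_follows(arg: Node, kernel: Node, copy: Node, graph: List[Node]) -> bool:
    """True if the copy_ epilogue is the only user of `arg` after `kernel`."""
    later = users_after(arg, kernel, graph)
    return len(later) == 1 and later[0] is copy


graph: List[Node] = []
x = add_node(graph, "x")
out = add_node(graph, "out")
tmp = add_node(graph, "add", x)                    # reads x before the kernel
kernel = add_node(graph, "triton_kernel", x, out)
copy = add_node(graph, "copy_", x, kernel)         # epilogue writing back into x

print(len(x.users))                                # 3, so `len(users) > 2` bails out
print(only_copy_follows(x, kernel, copy, graph))   # True
```

In this toy version the early `add` no longer blocks reinplacing, because only users positioned after the kernel are counted.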
Yeah, @Chillee and I talked about this. As you mentioned, this could be improved by doing exactly what you suggested. I could do that as a follow-up.
I think the concrete check is "are there any views of this node that are used after the mutation".
But yeah, like Oguz said, we mainly didn't do it out of laziness/simplicity.
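One way to read that check, sketched on a toy graph description rather than real graph nodes. The `Node` tuple, the storage ids, and `views_used_after` are all hypothetical names. The sketch assumes that nodes sharing a storage id are views of each other and that list position is graph order.

```python
from typing import List, NamedTuple


class Node(NamedTuple):
    """Toy graph entry (hypothetical): list position is graph order."""
    name: str
    storage: str          # nodes with the same storage id alias each other
    inputs: List[str]     # names of the nodes this node reads


def views_used_after(graph: List[Node], mutated: str, mutation: str,
                     epilogue: str) -> List[str]:
    """Names of views of `mutated` that are read after `mutation`.

    The copy_ `epilogue` is skipped, since it is the node we hope to remove.
    """
    by_name = {n.name: n for n in graph}
    position = {n.name: i for i, n in enumerate(graph)}
    views = {n.name for n in graph
             if n.storage == by_name[mutated].storage and n.name != mutated}
    used = []
    for node in graph[position[mutation] + 1:]:
        if node.name == epilogue:
            continue
        used.extend(name for name in node.inputs if name in views)
    return used


graph = [
    Node("x", "s0", []),
    Node("out", "s1", []),
    Node("view", "s0", ["x"]),            # view taken before the mutation
    Node("kernel", "s2", ["x", "out"]),   # the mutation
    Node("late_read", "s3", ["view"]),    # view used after the mutation
    Node("copy_", "s0", ["x", "kernel"]),
]

print(views_used_after(graph, "x", "kernel", "copy_"))  # ['view']
```

An empty result would mean no view of the mutated node is read after the mutation. This sketch does not count direct uses of the mutated node itself, so it would need to be combined with a separate check on those.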
):
    return False

if len(shared_view_nodes) > 2:  # Arg aliases another node other than copy_
This check is probably also too conservative. If there happen to be any views of the input in the graph, then this will also be > 2 and we won't reinplace.
Handling the aliasing case is a bit tricky though (but definitely feels solvable). Do we want to try to handle reinplacing in 100% of cases as part of this PR? A follow-up also feels totally reasonable, since this seems more like it's just the existing state of the reinplacing pass.
In particular: we'd ideally like to remove the `copy_()` on `x` for a case like this:
def f(x, out):
    x_view = x.view(-1)
    triton_kernel[grid](inp=x, out=out)
    out2 = x_view.mul(2)
    return out, out2
But we want to make sure not to remove the `copy_()` for a case like this (where we actually do mutate the alias of `x` later):
def f(x, out):
    x_view = x.view(-1)
    triton_kernel[grid](inp=x, out=out)
    x_view.mul_(2)
    return out
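The distinction between these two user-level examples can be restated as a tiny predicate. This is only a sketch of the intent described above, not a proposed implementation of the pass: `is_inplace` and `must_keep_copy` are hypothetical helpers, and the op lists are written by hand from the two functions. The one thing it relies on is the PyTorch naming convention that in-place methods end in a single underscore.

```python
from typing import Iterable


def is_inplace(op_name: str) -> bool:
    """PyTorch convention: in-place methods end in one underscore (mul_, add_)."""
    return op_name.endswith("_") and not op_name.endswith("__")


def must_keep_copy(later_alias_ops: Iterable[str]) -> bool:
    """Keep the copy_ if any alias of x is mutated after the kernel runs."""
    return any(is_inplace(op) for op in later_alias_ops)


# Ops applied to aliases of x after the triton kernel in each example above.
first_example = ["mul"]    # out2 = x_view.mul(2): a read only
second_example = ["mul_"]  # x_view.mul_(2): mutates the alias

print(must_keep_copy(first_example), must_keep_copy(second_example))  # False True
```

A real pass would have to derive this from the graph's alias information rather than from method names, which is the tricky part noted in the discussion.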
Yup, this could be a good follow-up too. For now, I kept `reinplace_scatters` as is.
… kernel via reinplace_scatters" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler [ghstack-poisoned]
… reinplace_scatters ghstack-source-id: fe31eff8fda0cba26767b5c390a1ab4e16e498fe Pull Request resolved: #111627
… reinplace_scatters ghstack-source-id: d236b8297081ddc46f2516e1a4cac3b5bd917d37 Pull Request resolved: #111627
… reinplace_scatters ghstack-source-id: aa11aaecc4c41b90ce424237ae77192479a5b6ae Pull Request resolved: #111627
… reinplace_scatters ghstack-source-id: 80c072d80716571c7a12f00a9ad1e8cf68c7440f Pull Request resolved: #111627
@pytorchbot merge
Merge started: your change will be merged once all checks pass (ETA 0-4 hours).
… reinplace_scatters (pytorch#111627) Pull Request resolved: pytorch#111627 Approved by: https://github.com/jansel ghstack dependencies: pytorch#111434
Stack from ghstack (oldest at bottom):
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov @ColinPeppler