Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Implement clone removal for user defined triton kernel via reinplace_scatters #111627

Closed
wants to merge 5 commits into from

Conversation

@pytorch-bot
Copy link

pytorch-bot bot commented Oct 20, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/111627

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e5236b0 with merge base bf01a7b (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

oulgen added a commit that referenced this pull request Oct 20, 2023
… reinplace_scatters

ghstack-source-id: 1d009e4c5e8ae86ea6407c737e69476ab5cd14c1
Pull Request resolved: #111627
dst.target == operator.getitem
and dst.args[0].kwargs["kwargs"][dst.args[1]] == node.args[0]
):
dst = dst.args[0]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little iffy about this since dst becomes triton_kernel_wrapper_functional_proxy here but with above check that should be safe

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems a bit confusing to me, since getitem is the actual node that we're copying into the input, not triton_kernel_wrapper_functional_proxy. But I guess it's not wrong 🤔

@oulgen oulgen added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 20, 2023
@@ -627,44 +630,69 @@ def reinplace_scatters(graph):
for node in reversed(graph.nodes):
storage_to_nodes[get_node_storage(node)].append(node)
if node.target == aten.copy_.default:
copy_args_to_copy_nodes[(node.args[0], node.args[1])] = node
src = node.args[0]
dst = node.args[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm isn't this flipped? I think the user code dst.copy_(src) will show up in the graph as:

torch.ops.aten.copy_.default(dst, src)

return False

# Check for any uses other than current node and copy_ epilogue
if len(mutated_arg.users) > 2:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this check will miss some cases. In particular, mutated_arg can have more than 2 users, but it can still be ok to reinplace. Example:

def f(x, out):
    # x will now have **three** users: `add`, `triton_kernel`, and `copy_`
    tmp = torch.add(x, 1)
    triton_kernel[grid](inp=x, out=out)
    return out, tmp

Hmm, I think what we probably want is: if we look at users of mutated_arg later in the graph than the triton kernel, copy_() should be the only user (if this is the case, then it's safe to reinplace).

I'm not sure if there's a builtin way in FX to check "number of users after a given node", but there's a util here that does something similar: https://github.com/pytorch/pytorch/blob/main/torch/fx/passes/reinplace.py#L176

cc @Chillee (lmk if this sounds right to you)

Copy link
Contributor Author

@oulgen oulgen Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, @Chillee and I talked about this. As you mentioned, this could be improved by doing exactly what you suggested. I could do that as follow up.

Copy link
Contributor

@Chillee Chillee Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the concrete check is

"are there any views of this node that are used after the mutation"

But yeah, like Oguz said, mainly didn't do it out of laziness/simplicity.

):
return False

if len(shared_view_nodes) > 2: # Arg aliases another node other than copy_
Copy link
Contributor

@bdhirsh bdhirsh Oct 20, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is probably also too conservative. if there happen to be any views of the input in the graph, then this will also be > 2 and we won't reinplace.

Handling the aliasing case is a bit tricky though (but definitely feels solvable). Do we want to try to handle reinplacing in 100% of cases as part of this PR? Followup also feels totally reasonable, since this seems more like it's just the existing state of the reinplacing pass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In particular: we'd ideally like to remove the copy_() on x for a case like this:

def f(x, out):
    x_view = x.view(-1)
    triton_kernel[grid](inp=x, out=out)
    out2 = x_view.mul(2)
    return out, out2

But we want to make sure not to remove the copy_() for a case like this (where we actually do mutate the alias of x later):

def f(x, out):
    x_view = x.view(-1)
    triton_kernel[grid](inp=x, out=out)
    x_view.mul_(2)
    return out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, this could be a good follow up too. For now, i kept the reinplace_scatters as is

… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 20, 2023
… reinplace_scatters

ghstack-source-id: fe31eff8fda0cba26767b5c390a1ab4e16e498fe
Pull Request resolved: #111627
… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 20, 2023
… reinplace_scatters

ghstack-source-id: d236b8297081ddc46f2516e1a4cac3b5bd917d37
Pull Request resolved: #111627
… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 22, 2023
… reinplace_scatters

ghstack-source-id: aa11aaecc4c41b90ce424237ae77192479a5b6ae
Pull Request resolved: #111627
… kernel via reinplace_scatters"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler

[ghstack-poisoned]
oulgen added a commit that referenced this pull request Oct 22, 2023
… reinplace_scatters

ghstack-source-id: 80c072d80716571c7a12f00a9ad1e8cf68c7440f
Pull Request resolved: #111627
@oulgen oulgen added the topic: not user facing topic category label Oct 22, 2023
@oulgen
Copy link
Contributor Author

oulgen commented Oct 22, 2023

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@facebook-github-bot facebook-github-bot deleted the gh/oulgen/10/head branch October 26, 2023 14:24
andreigh pushed a commit to andreigh/pytorch that referenced this pull request Oct 26, 2023
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants