-
Notifications
You must be signed in to change notification settings - Fork 25.7k
Add min cut partitioner for AOT+nvFuser #88204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88204
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 FailuresAs of commit 91c148b: This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, minor comments
|
|
||
| # First we trace the graph conditionally decomposing nodes | ||
| # that can be sent to the nvfuser executor | ||
| with TorchRefsNvfuserCapabilityMode(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean that later we are speculatively lowering the partitioned graph to nvprim again in prims_executor?
should we just skip the speculative lowering, or is the second lowering there to catch some other decomposed op?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately yes.
Initially, I did leave this only in the partitioner code, but unfortunately, AOT has two different code paths for the case when at least one of the inputs requires grad and none requires grad, for the latter case "partition_fn" is not used at all, and there's no way to pass that information further without modifying AOT code. All inputs to the "fw_compiler" functions are detached and do not require grad, so it's not possible to determine this from within there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay with a bit of monkey patching unnecessary lowering step is avoided in 2a1103c
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good~ thx for patching this~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A quick follow up question.
So now we are not lowering to nvprims after aot_autograd, does this mean post-autograd decomposition won't be lowered to nvprims now 🤯
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know exactly when post-autograd decomposition happens.. Hopefully it's not the case and we can still use it.
Can we get a CI test to guard/verify that behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't use decompositions (aot_module_simplified(decompositions=None)). Even if we used it, it would happen before calling the partitioning function.
| prim_gm = make_fx(func)(*joint_inputs) | ||
|
|
||
| # all nvprims for now | ||
| recomputable_ops = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if we should define this recomputable_ops inside nvfuser_prims.py, where new nvprims are added? Just to avoid accidentally adding more normalization/reduction ops and having them default in recompute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or maybe add compulsory markers to nvprims: this function is a reduction, and this one is a normalization, and so on.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds like a better idea. So let's put a TODO here and we can clean it up afterwards.
|
@jansel can you please approve to help merge changes to Dynamo's |
|
@pytorchbot merge -g |
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 2 additional jobs have failed, first few of them are: trunk ,trunk / macos-12-py3-arm64-mps / Run MPS tests Details for Dev Infra teamRaised by workflow job |
|
@pytorchbot merge -f "mps failure is unrelated" |
Merge startedYour change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Here we mark most of `torch.ops.nvprims` as something that can be recomputed in the backward passes (and hopefully fused). TODO: - [x] Add a test after pytorch#88186 is merged Pull Request resolved: pytorch#88204 Approved by: https://github.com/jjsjann123, https://github.com/jansel
Here we mark most of
torch.ops.nvprimsas something that can be recomputed in the backward passes (and hopefully fused).TODO:
cc @kevinstephano @jjsjann123 @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire