Conversation

@IvanYashchuk IvanYashchuk commented Nov 1, 2022

Here we mark most of torch.ops.nvprims as something that can be recomputed in the backward passes (and hopefully fused).
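The effect of marking ops as recomputable can be sketched outside of PyTorch with a toy save-vs-recompute decision (illustrative names only, not the PR's actual code): an intermediate needed by the backward pass is saved only if its producing op is not marked recomputable; otherwise it is re-executed in backward.

```python
# Toy illustration (not the PR's code): ops in RECOMPUTABLE are cheap
# pointwise ops whose outputs are recomputed in backward instead of saved.
RECOMPUTABLE = {"add", "mul", "exp"}        # stand-ins for nvprims entries
NOT_RECOMPUTABLE = {"sum", "var_mean"}      # reductions: better to save

def outputs_to_save(forward_ops, needed_in_backward):
    """Save an intermediate only if backward needs it AND its op
    is not marked as recomputable."""
    return [
        name for name, op in forward_ops
        if name in needed_in_backward and op not in RECOMPUTABLE
    ]

forward_ops = [("t1", "mul"), ("t2", "sum"), ("t3", "exp")]
saved = outputs_to_save(forward_ops, needed_in_backward={"t1", "t2", "t3"})
print(saved)  # only the reduction's output is saved; t1, t3 are recomputed
```

The real decision is made by the min-cut partitioner over the joint forward/backward graph; this sketch only shows the save-vs-recompute trade-off the marking enables.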

TODO:

- [x] Add a test after pytorch#88186 is merged

cc @kevinstephano @jjsjann123 @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

pytorch-bot bot commented Nov 1, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/88204

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failure

As of commit 91c148b:

The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jjsjann123 (Collaborator) left a comment


LGTM, minor comments


# First we trace the graph conditionally decomposing nodes
# that can be sent to the nvfuser executor
with TorchRefsNvfuserCapabilityMode():
Collaborator:

Does this mean that later we are speculatively lowering the partitioned graph to nvprim again in prims_executor?

should we just skip the speculative lowering, or is the second lowering there to catch some other decomposed op?

Collaborator (Author):

Unfortunately yes.

Initially, I left this only in the partitioner code, but AOT Autograd has two different code paths: one for when at least one of the inputs requires grad, and one for when none do. In the latter case, `partition_fn` is not used at all, and there's no way to pass that information further without modifying the AOT code. All inputs to the `fw_compiler` function are detached and do not require grad, so it's not possible to determine this from within it.
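The two code paths described above can be sketched as follows (a simplified stand-in for AOT Autograd's dispatch, not its real code): `partition_fn` only runs on the training path, so anything configured through it never reaches the inference path.

```python
# Simplified sketch (not AOT Autograd's actual implementation):
# partition_fn is consulted only when some input requires grad.
def aot_compile(inputs_require_grad, fw_compiler, partition_fn=None):
    if any(inputs_require_grad):
        # Training path: a joint forward+backward graph is traced
        # and split by partition_fn.
        fw_graph, bw_graph = partition_fn("joint_graph")
        return fw_compiler(fw_graph), bw_graph
    # Inference path: partition_fn is never called.
    return fw_compiler("forward_graph"), None

split = lambda joint: (joint + ":fw", joint + ":bw")

_, bw_train = aot_compile([True, False], fw_compiler=lambda g: g, partition_fn=split)
_, bw_infer = aot_compile([False, False], fw_compiler=lambda g: g, partition_fn=split)
print(bw_train, bw_infer)  # joint_graph:bw None
```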

Collaborator (Author):

Okay, with a bit of monkey patching the unnecessary lowering step is avoided in 2a1103c.

Collaborator:

Looks good, thanks for patching this!

Collaborator:

A quick follow-up question.

Now that we are not lowering to nvprims after aot_autograd, does this mean post-autograd decompositions won't be lowered to nvprims anymore? 🤯

Collaborator:

I don't know exactly when post-autograd decomposition happens. Hopefully that's not the case and we can still use it.
Can we add a CI test to guard/verify that behavior?

Collaborator (Author):

We don't use decompositions (`aot_module_simplified(decompositions=None)`). Even if we did, they would be applied before the partitioning function is called.
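The ordering claimed above can be illustrated with a minimal sketch (hypothetical stubs, not PyTorch's real implementation): decompositions, when supplied, are applied while the joint graph is traced, and only afterwards is the partitioning function invoked.

```python
# Minimal ordering sketch (not PyTorch's real code): decompositions run
# during tracing, before partition_fn is ever called.
events = []

def trace_joint_graph(decompositions):
    if decompositions:
        events.append("apply_decompositions")
    events.append("trace_joint_graph")
    return "joint_graph"

def aot_module_simplified_sketch(decompositions=None, partition_fn=None):
    graph = trace_joint_graph(decompositions)
    events.append("partition_fn")
    return partition_fn(graph)

aot_module_simplified_sketch(decompositions={"aten.addmm": "decomp"},
                             partition_fn=lambda g: (g, g))
print(events)  # decompositions precede partitioning
```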

prim_gm = make_fx(func)(*joint_inputs)

# all nvprims for now
recomputable_ops = {
Collaborator:

Wondering if we should define this `recomputable_ops` inside nvfuser_prims.py, where new nvprims are added? Just to avoid accidentally adding more normalization/reduction ops and having them default to being recomputed.

Collaborator (Author):

Or maybe add compulsory markers to nvprims: this function is a reduction, and this one is a normalization, and so on.

Collaborator:

That sounds like a better idea. So let's put a TODO here and we can clean it up afterwards.
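The marker idea discussed above could look something like this (a hypothetical registry API, not nvfuser_prims.py's actual code): each prim is registered with a mandatory kind, and `recomputable_ops` is derived by excluding reductions and normalizations, so a newly added prim can never silently default to "recomputable".

```python
# Hypothetical marker registry (not the real nvfuser_prims.py API):
# every prim must declare its kind at registration time.
PRIM_REGISTRY = {}

def register_prim(name, kind):
    assert kind in {"pointwise", "reduction", "normalization"}
    PRIM_REGISTRY[name] = kind

register_prim("add", "pointwise")
register_prim("exp", "pointwise")
register_prim("sum", "reduction")
register_prim("native_batch_norm", "normalization")

# Derived automatically: only pointwise prims are recomputable,
# so reductions/normalizations can't be added to the set by accident.
recomputable_ops = {n for n, k in PRIM_REGISTRY.items() if k == "pointwise"}
print(sorted(recomputable_ops))  # ['add', 'exp']
```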

@IvanYashchuk IvanYashchuk marked this pull request as ready for review November 2, 2022 11:13
@IvanYashchuk (Collaborator, Author):

@jansel can you please approve to help merge changes to Dynamo's nvprims_nvfuser backend?

@IvanYashchuk IvanYashchuk requested a review from jansel November 4, 2022 16:41
@IvanYashchuk (Collaborator, Author):

@pytorchbot merge -g

@IvanYashchuk (Collaborator, Author):

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 8, 2022
@pytorchmergebot:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


@pytorchmergebot:

Merge failed

Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / macos-12-py3-arm64-mps / Run MPS tests


@IvanYashchuk
Copy link
Collaborator Author

@pytorchbot merge -f "mps failure is unrelated"

@pytorchmergebot:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).


kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Here we mark most of `torch.ops.nvprims` as something that can be recomputed in the backward passes (and hopefully fused).

TODO:
- [x] Add a test after pytorch#88186 is merged

Pull Request resolved: pytorch#88204
Approved by: https://github.com/jjsjann123, https://github.com/jansel