
[Reopen #114036] Allow "must recompute" in torch.compile + selective checkpointing (SAC) #129295

Closed · wants to merge 4 commits

Conversation


pytorch-bot (bot) commented on Jun 22, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129295

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 5ca28a5 with merge base aa4ee2c:

BROKEN TRUNK - The following job failed but was also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

yf225 added a commit that referenced this pull request Jun 22, 2024
…checkpointing (SAC)

ghstack-source-id: dfc5d8f7d50844b2fa49726ed114a55c014ab89e
Pull Request resolved: #129295
yf225 added a commit that referenced this pull request Jun 24, 2024
…checkpointing (SAC)

ghstack-source-id: aa8ecb15071d870e0bf14243933c5b32475349f7
Pull Request resolved: #129295
torch/_functorch/partitioners.py (review thread, resolved)
-    must_recompute(user)
-    and user.meta["recompute"] > node.meta["recompute"]
+    prefer_recompute(user)
+    and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]

Contributor commented:

Is this an independent bugfix? Or is this part of the "partitioner should respect prefer_recompute" change?

Mostly a curiosity question: my understanding from the comment above is that if you run code like:

checkpoint_f = checkpoint(f)
checkpoint_g = checkpoint(g)
out = checkpoint_f(checkpoint_g(inp))

Then AC requires us to save the inputs to f (outputs of g), but in our tag-based system every node would have the recompute tag: so you need some notion of "which AC subgraph does a node belong to" (ac_graph_id) to tell the partitioner that it should save the output of the first subgraph. Is that right?
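
For context, a minimal eager-mode sketch of the scenario above (illustrative only, not code from this PR; the functions and shapes are made up): two separately checkpointed regions composed together, where the output of the first region feeds the second and therefore must be saved, even though every op inside both regions carries a recompute tag.

import torch
from torch.utils.checkpoint import checkpoint

def g(x):
    return torch.relu(x @ x)

def f(y):
    return torch.sin(y).sum()

inp = torch.randn(8, 8, requires_grad=True)

# Two independent AC regions. Inside each region every op is tagged for
# recomputation, but the boundary tensor y (output of g, input of f) must
# be saved so the second region can be recomputed from it in backward.
y = checkpoint(g, inp, use_reentrant=False)
out = checkpoint(f, y, use_reentrant=False)
out.backward()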

yf225 (Contributor Author) replied:

This is part of the "recompute" tag cleanup: the tag used to carry two meanings, 1) the recompute policy and 2) the AC subgraph ID, and this change splits them into two separate tags.

Yes, your understanding is exactly right :)
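
To make the split concrete, here is a hedged sketch of the two separate questions the partitioner can ask once the tags are disentangled; the meta keys mirror the diff above, but the helper bodies and the same_ac_region name are illustrative assumptions, not the PR's actual implementation.

from torch.utils.checkpoint import CheckpointPolicy

def prefer_recompute(node):
    # Policy question only: does the user's SAC policy lean toward
    # recomputing this node in backward? (illustrative body)
    return node.meta.get("recompute", None) in (
        CheckpointPolicy.PREFER_RECOMPUTE,
        CheckpointPolicy.MUST_RECOMPUTE,
    )

def same_ac_region(node, user):
    # Subgraph-identity question only: were node and user traced from the
    # same checkpointed region? If not, the boundary value between the two
    # regions must be saved. (hypothetical helper, not from the PR)
    return node.meta.get("ac_graph_id") == user.meta.get("ac_graph_id")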

@@ -808,8 +813,7 @@ def should_ban_recomputation(node):
         return False
     if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
         return False
-    # NB: "recompute" == 0 means that must save this node.
-    if node.meta.get("recompute", None) == 0:
+    if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:

bdhirsh (Contributor) commented on Jun 25, 2024:

It looks like right above this condition, if the op is a view or a lift_fresh_copy, then we will not ban recomputation, even if the user marked it as MUST_SAVE.

Do you think we should either error here, or prefer the user annotation instead?

Maybe a more general question: for ops that the partitioner is already strongly opinionated about saving or not (e.g. randomness or view ops), should we error when a user tries to change the partitioner's behavior for them? Or promise to always respect the user's intent?

soulitzer (Contributor) commented on Jun 25, 2024:

This is a good point. I think we should error when the user marks a view or lift_fresh, etc. as MUST_SAVE, rather than ignore the annotation or silently do something less optimal.

However, it is generally tricky to absolutely respect the MUST_SAVE condition, e.g. in the case where some code gets DCE'd.

Another case is if the user chooses to MUST_SAVE some tensor, but the backward formula reduces this value before using it; then we might rather save the value post-reduction.

But these two are really just consequences of the fact that MUST_NOT_RECOMPUTE became MUST_SAVE... maybe it's worth renaming it back lol.

yf225 (Contributor Author) replied:

Yeah, I feel we should bias toward respecting user intent: if the user asks for MUST_SAVE, we should give them MUST_SAVE as much as possible, since they expressed it explicitly.
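
For context, a hedged example of how a user expresses MUST_SAVE through the public SAC API in torch.utils.checkpoint (illustrative usage, not this PR's diff; the choice of aten.mm as the op to save is arbitrary).

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def policy_fn(ctx, op, *args, **kwargs):
    # Explicit user intent: always save matmul outputs, recompute the rest.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE

def fn(x):
    return torch.sin(x @ x)

x = torch.randn(4, 4, requires_grad=True)
context_fn = partial(create_selective_checkpoint_contexts, policy_fn)
out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()

Under the discussion above, the partitioner would honor the MUST_SAVE tag on the mm output, with the open question being whether to error instead when the tagged op is one the partitioner is strongly opinionated about (e.g. views or lift_fresh).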

yf225 added a commit that referenced this pull request Jun 25, 2024
…checkpointing (SAC)

ghstack-source-id: 3f8e1d7e8618505dd2ea67537a4504f57cdfdcbb
Pull Request resolved: #129295
yf225 added a commit that referenced this pull request Jun 25, 2024
…checkpointing (SAC)

ghstack-source-id: 7a9f0491de59fc37fb98c8b9e6848b6f7524a624
Pull Request resolved: #129295
yf225 (Contributor Author) commented on Jun 25, 2024:

@pytorchbot merge -f "unrelated failures"

pytorchmergebot (Collaborator) replied:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge while ignoring current failures. This allows currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
