
@jamesjwu (Contributor) commented Jun 7, 2024

Stack from ghstack (oldest at bottom):

If any mutated inputs require grad, we should still run autograd with a backward graph, even when none of the outputs require grad. This fixes two tests: test_input_mutation_alias_everything and test_view_detach.

Fixes #128035

[ghstack-poisoned]
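The situation the PR describes can be sketched in eager mode (a hedged illustration; the function `f` and the tensors below are made up for this example, not taken from the test suite). Even though the returned output is detached and does not require grad, the in-place mutation of an input that requires grad still needs autograd tracking, which is why a compiled version must take the training (joint) path rather than the inference path:

```python
import torch

def f(x):
    x.mul_(2)          # mutate an input that requires grad
    return x.detach()  # the output itself does not require grad

base = torch.ones(3, requires_grad=True)
x = base.clone()       # non-leaf copy of base, safe to mutate in-place
out = f(x)

assert not out.requires_grad   # no output requires grad...
x.sum().backward()             # ...but the mutation must still be differentiable
assert torch.equal(base.grad, torch.full((3,), 2.0))
```

Under AOTAutograd, deciding "inference graph" purely from the outputs would drop the gradient path through the mutated input; checking mutated inputs as well, as this PR does, keeps the backward graph.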
@jamesjwu jamesjwu requested review from Chillee and ezyang as code owners June 7, 2024 17:15
pytorch-bot bot commented Jun 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128229

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ You can merge normally! (2 Unrelated Failures)

As of commit 3e8b9bb (failed to retrieve merge base; please contact dev infra):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jamesjwu jamesjwu added ciflow/trunk Trigger trunk jobs on your pull request topic: bug fixes topic category labels Jun 7, 2024
@jamesjwu jamesjwu changed the title Run autograd if any mutations on inputs that require grad [easy] Run autograd if any mutations on inputs that require grad Jun 7, 2024


- def skipIfDynamoInput(reason, xfail=False):
+ def skipIfDynamoInput(reason):

@jamesjwu (Contributor, author) commented:

My previous PR didn't actually test xfail properly here; it was just skipping those tests anyway. It turns out that in Python 3 you can't xfail tests from individual methods, only at the class level, so I implemented expected failure with xfail_inherited_tests instead.

@jamesjwu (Contributor, author) commented Jun 8, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@bdhirsh (Contributor) commented Jun 10, 2024

cc @IvanKobzarev / @tugsbayasgalan - another subtle AOTAutograd correctness issue that would be good to understand (also around properly picking inference vs. training)

@bdhirsh bdhirsh requested a review from tugsbayasgalan June 10, 2024 14:09
@bdhirsh (Contributor) commented Jun 10, 2024

thanks for the fix!

TharinduRusira pushed a commit to TharinduRusira/pytorch that referenced this pull request Jun 14, 2024
…orch#128229)

If any inputs are mutated that require grad, even if all the outputs don't require grad, we should still run autograd with a backwards graph. This fixes two tests: test_input_mutation_alias_everything and test_view_detach.

Fixes pytorch#128035
Pull Request resolved: pytorch#128229
Approved by: https://github.com/aorenste
@github-actions github-actions bot deleted the gh/jamesjwu/38/head branch July 12, 2024 01:56