
AsyncCollectiveTensor: prevent wait_tensor() calls on graph inputs from getting DCEd #125677

Closed
wants to merge 2 commits

Conversation

@bdhirsh (Contributor) commented May 7, 2024

@wanchaol was seeing the loss eventually become NaN when compiling individual transformer blocks in torchtitan - with this patch I no longer see the NaN loss.

The problem is the following:

(1) It is possible for graph inputs to a compiled region to be AsyncCollectiveTensors. In particular: when we compile individual transformer blocks in the llama model, the first layer (the embedding layer) is run in eager mode, and it outputs an AsyncCollectiveTensor that is fed to the first transformer block (a minimal sketch of this setup follows).
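
As an illustration, here is a minimal, hypothetical sketch of how an AsyncCollectiveTensor can end up as a compiled-region input. It assumes a single-process gloo process group so it runs without torchrun; `eager_embedding` and `transformer_block` are illustrative stand-ins, not the real torchtitan modules:

```python
import os
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Single-process "world" so the sketch runs standalone (assumption for illustration).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

def eager_embedding(x):
    # Stand-in for the eager-mode embedding layer: a functional collective
    # run in eager mode returns an AsyncCollectiveTensor wrapping the
    # not-yet-waited result.
    return funcol.all_reduce(x, "sum", group=dist.group.WORLD)

@torch.compile
def transformer_block(x):
    # Stand-in for the first compiled transformer block.
    return x * 2

act = eager_embedding(torch.ones(4))
print(type(act).__name__)      # AsyncCollectiveTensor
out = transformer_block(act)   # the ACT arrives as a graph input

dist.destroy_process_group()
```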

(2) Ideally, that AsyncCollectiveTensor graph input should desugar into a wait_tensor() op that shows up at the beginning of the graph.

(3) The way this is supposed to happen: AOTAutograd traces through AsyncCollectiveTensor's __torch_dispatch__, emitting a wait_tensor() call before dispatching to any of the other ops in the function being traced.

(4) However, trigger_wait() was being called in a way that ignored its output (self.elem was returned directly), so the traced wait_tensor ops had no users and got DCE'd (see the sketch below).
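
A simplified sketch of the pattern (not the exact PR diff): the tensor returned by wait_tensor() has to flow into the rest of the computation, otherwise the traced wait_tensor node has no users and dead-code elimination removes it. `AsyncCollectiveTensorSketch`, `unwrap_buggy`, and `unwrap_fixed` are illustrative names:

```python
from torch.distributed._functional_collectives import wait_tensor

class AsyncCollectiveTensorSketch:
    # Minimal stand-in with just the fields the sketch needs.
    def __init__(self, elem):
        self.elem = elem
        self.completed = False

    def trigger_wait(self):
        # Fixed pattern: thread the waited tensor through, so the traced
        # wait_tensor() node is a data dependency of everything downstream.
        if not self.completed:
            self.elem = wait_tensor(self.elem)
            self.completed = True
        return self.elem

def unwrap_buggy(e):
    e.trigger_wait()   # the traced wait_tensor() output is discarded...
    return e.elem      # ...so the wait_tensor node has no users and is DCE'd

def unwrap_fixed(e):
    return e.trigger_wait()   # downstream ops now consume the waited tensor
```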

Stack from ghstack (oldest at bottom):

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k


pytorch-bot bot commented May 7, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125677

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit edba89d with merge base ab80a59:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request May 7, 2024
AsyncCollectiveTensor: wait on pending collectives before executing compiled fw

ghstack-source-id: 5704a7ddf8560e00cd6bcba6ef026d7fd9ed8872
Pull Request resolved: #125677
This patch is relatively low LoC but I'm not very satisfied with it (internal post coming soon)
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label May 7, 2024
bdhirsh added a commit that referenced this pull request May 7, 2024
AsyncCollectiveTensor: wait on pending collectives before executing compiled fw

ghstack-source-id: 7e72f74afe68922c80ca817948c97aaec3fc993e
Pull Request resolved: #125677
@bdhirsh bdhirsh changed the title AsyncCollectiveTensor: wait on pending collectives before executing compiled fw AsyncCollectiveTensor: prevent wait_tensor() calls on graph inputs from getting DCEd May 7, 2024
@albanD albanD removed their request for review May 7, 2024 13:54
@wanchaol (Contributor) left a comment

This looks pretty good to me! Thanks for diving into the issues and fixing them quickly!

@yifuwang (Contributor) commented May 7, 2024

Very nice fix! Thank you!

@bdhirsh bdhirsh added the release notes: distributed (dtensor) release notes category label May 8, 2024
@bdhirsh (Contributor, Author) commented May 8, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 8, 2024
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor)