AsyncCollectiveTensor: prevent wait_tensor() calls on graph inputs from getting DCEd #125677
Conversation
…executing compiled fw" This patch is relatively low LoC but I'm not very satisfied with it (internal post coming soon) [ghstack-poisoned]
This looks pretty good to me! Thanks for diving into the issues and fixing them quickly!
Very nice fix! Thank you!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
@wanchaol was seeing the loss eventually become NaN when compiling individual transformer blocks in torchtitan; with this patch I no longer see the NaN loss.

The problem is the following:

(1) It is possible to have graph inputs to a compiled region that are AsyncCollectiveTensors. In particular: when we compile individual transformer blocks in the llama model, the first layer (the embedding layer) runs in eager mode, and it outputs an AsyncCollectiveTensor that is fed to the first transformer block.

(2) Ideally, we would like that AsyncCollectiveTensor graph input to desugar into a `wait_tensor()` op that shows up at the beginning of the graph.

(3) The way this is supposed to happen is: AOTAutograd traces through the `__torch_dispatch__` of AsyncCollectiveTensor, tracing out a `wait_tensor()` call before dispatching to any of the other ops in the function we are tracing.

(4) However, `trigger_wait()` was getting called in a way where we would ignore its output (and return `self.elem` directly), which would cause the `wait_tensor()` ops to get DCE'd. A minimal sketch of this failure mode follows below.

Stack from ghstack (oldest at bottom):
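To make the failure mode concrete, here is a minimal, self-contained sketch of the DCE behavior, using `torch.fx` as the tracer. It does not use the real AsyncCollectiveTensor or `wait_tensor()`; `fake_wait` below is a hypothetical stand-in for the traced wait op. The point is just that a traced op whose output is never consumed has no users in the graph and gets removed by dead-code elimination:

```python
# A minimal sketch of the DCE failure mode, using torch.fx as a stand-in
# tracer. `fake_wait` is a hypothetical placeholder for the traced
# wait_tensor() op -- this is NOT the real AsyncCollectiveTensor code.
import torch
import torch.fx as fx

def fake_wait(t):
    # From the tracer's perspective this is a pure op: it only stays in the
    # graph if something downstream consumes its output.
    return t + 0

def buggy(x):
    # Mirrors the bug: the wait is traced, but its output is ignored and the
    # original tensor (think `self.elem`) is used instead.
    fake_wait(x)
    return x * 2

def fixed(x):
    # Mirrors the fix: the waited tensor is what flows into later ops.
    x = fake_wait(x)
    return x * 2

def wait_survives(fn):
    gm = fx.symbolic_trace(fn)
    gm.graph.eliminate_dead_code()  # the DCE pass in question
    gm.recompile()
    # fake_wait was inlined as an `add` node; check whether it is still there.
    return any(n.op == "call_function" and "add" in str(n.target)
               for n in gm.graph.nodes)

print(wait_survives(buggy))  # False: the wait was dead code and got removed
print(wait_survives(fixed))  # True: using the wait's output keeps it alive
```

The fix follows the same principle inside AsyncCollectiveTensor's unwrapping path: return the result of `trigger_wait()` instead of `self.elem`, so the traced `wait_tensor()` has a user and survives DCE.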
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k