New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FSDP] fix: fix for fsdp exec order pre fwd record #110138
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/110138
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 2 PendingAs of commit 51ca211 with merge base 33d8f5f (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
@Edwiv could you sign the CLA when you get the chance? |
done |
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
When the sharding_strategy is set to SHARD_GRAD_OP and forward_prefetch=True, during direct validation run, self.is_first_iter will always be True (because training=False, iter+1 is not executed). Additionally, the _pre_forward_order_index of the first handle entering the record_pre_forward function is 0. This causes the handle to have a False result in the if condition at line 166 when entering the record_pre_forward function again (the expected value should be True because _pre_forward_order_index has actually been assigned a value). As a result, the first handle is repetitively added to handles_pre_forward_order, leading to incorrect prefetching order.