Separate unshard stream for each process group with FSDP #116611
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116611
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure as of commit c1f67a6 with merge base 95a86ed.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Please fix your submodules and make sure to run lintrunner on this PR.
Force-pushed from c998892 to 80baa04 (Compare)
This sounds good to me!
I recall that the concurrent NCCL communicator issue was resolved in a recent NCCL update? cc: @kwen2501
@awgu The test linked here was failing because a …
@snarayan21 Thanks for the pointer! This looks like it is failing because the dummy CPU … Can we change Lines 55 to 61 in 7073dc6 to set `self.priority = priority`?
Force-pushed from f3e0fdd to c1f67a6 (Compare)
Fixed the dummy …
@awgu We are currently seeing some issues with deadlocks when using multiple process groups. Do you have more information on what change to NCCL resolved this, and maybe an NCCL version? That would be super helpful, thanks!
LGTM!
Regarding the concurrent NCCL PGs, I am not too sure. @kwen2501 do you have any insights?
@snarayan21 There is still a unit test failing 🤔
I might need to find some time to patch your PR and look into this one. I do not have a sense of why it is failing off the top of my head. We may want to get some clarity on the deadlocks too before landing.
@awgu One datapoint for debugging this unit test failure: it seems to work for the first iteration but not the second.
It should be OK to use multiple NCCL communicators for unshard operations (all-gathers), provided you ensure that the all-gathers are issued in the same order on every rank. If you don't ensure this, you could observe an all-gather timeout.
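The ordering requirement above can be made concrete with a small sketch. The helper name and the idea of keying all-gathers by a stable module index are our own illustration, not an FSDP API; the point is only that every rank must enqueue the same collectives in the same sequence.

```python
def allgather_issue_order(pending):
    """Return a deterministic issue order for pending all-gathers.

    ``pending`` maps a stable module index (identical on every rank)
    to that module's all-gather. Sorting by the index guarantees that
    every rank enqueues its NCCL collectives in the same order, which
    is what avoids mismatched collectives and all-gather timeouts.
    """
    return [pending[idx] for idx in sorted(pending)]

# Two ranks may *discover* pending all-gathers in different orders...
rank0_pending = {0: "ag_layer0", 1: "ag_layer1"}
rank1_pending = {1: "ag_layer1", 0: "ag_layer0"}
# ...but both must issue them identically:
assert allgather_issue_order(rank0_pending) == allgather_issue_order(rank1_pending)
```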
Running into an issue with non-deterministic loss curves with this patch. I'm working on a fix and will update the PR when that happens.
@awgu I'm still looking into how to address this problem of the computation stream waiting on all-gather operations that the computation actually doesn't require. Is there a way to address this without using multiple unshard streams? From what I can tell, there's no way to make the computation stream wait on only part of the pending work in the unshard stream.
@snarayan21 Using CUDA events with a single unshard stream may not address this issue, because waiting on an event from the unshard stream still waits for all work issued to that stream before the event was recorded. Were you able to understand better why the fix in the PR is insufficient?
For per-parameter sharding, we are using CUDA events whenever possible to avoid the heavier-weight CUDA stream synchronization. However, as mentioned in the previous comment, waiting on an event will still wait for all pending work issued before that event was recorded. One question I have is: for your use case, does the order in which all-gathers are issued by the CPU match the order in which the all-gathers are expected to be waited on by the computation stream?
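The event semantics described above can be modeled with a toy in-order stream. This is a pure-Python illustration of the CUDA rule, not the CUDA API itself: a stream executes ops in FIFO order, so waiting on an event recorded at some position waits for every op enqueued before that position.

```python
def work_implied_by_event_wait(stream_ops, event_pos):
    """All ops enqueued on a stream before an event was recorded.

    Mirrors CUDA stream/event semantics: a stream executes in order,
    so waiting on an event recorded after ``event_pos`` ops implies
    completion of all of them, not just the one you care about.
    """
    return stream_ops[:event_pos]

# One shared unshard stream carrying two unrelated all-gathers:
unshard_ops = ["allgather_pg_default", "allgather_pg_custom"]
# An event recorded after the second all-gather:
event_pos = 2
# Computation that only needs the custom-PG all-gather still ends up
# waiting for the default-PG all-gather as well:
assert work_implied_by_event_wait(unshard_ops, event_pos) == unshard_ops
```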
@awgu Yeah, the more I think about it and investigate, it looks like the all-gather issue order is partially incorrect (in the backwards pass, due to the order in which backward ops are executed). As long as the order that kernels are queued is correct from the CPU side, and there are no unintended dependencies between communication and computation, then the stream waits should be correct, afaik. Ideally we should never see unintended waiting if the issue order is correct, right?
We don't have much control over the order in which autograd's engine traverses the DAG of ready backward ops. If FSDP encounters a case like b = foo(a); c = bar(a), where foo and bar are both FSDP-wrapped modules that depend on a common input, I wonder how we deal with ensuring the backward ordering is consistent between ranks? It's valid for autograd to pick either ordering of dFoo, dBar or dBar, dFoo, since they are peers in the graph. This potentially leads to 2 subproblems.
Do we have a solution for (2) today in FSDP, or are we just not hitting this case in real models? @awgu
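The peer-ordering hazard above can be shown with a toy dependency graph. The helper is our own illustration, not the autograd engine: for `b = foo(a); c = bar(a)`, both `dFoo, dBar` and `dBar, dFoo` are valid topological orders, so two ranks may legitimately disagree on which backward hook fires first.

```python
from itertools import permutations

def valid_backward_orders(deps):
    """Enumerate all valid execution orders of a backward graph.

    ``deps`` maps each node to the set of nodes that must run before
    it. Any permutation that respects those constraints is an order
    the engine could legally pick.
    """
    nodes = list(deps)
    orders = []
    for perm in permutations(nodes):
        seen = set()
        if all(deps[n] <= seen or seen.add(n) for n in perm
               if (deps[n] <= seen and (seen.add(n) or True))):
            pass  # (readability: see explicit loop below instead)
    # Explicit version, kept simple on purpose:
    orders = []
    for perm in permutations(nodes):
        seen, ok = set(), True
        for n in perm:
            if not deps[n] <= seen:
                ok = False
                break
            seen.add(n)
        if ok:
            orders.append(perm)
    return orders

# dFoo and dBar are peers; both feed grad accumulation dA for input a:
deps = {"dFoo": set(), "dBar": set(), "dA": {"dFoo", "dBar"}}
orders = valid_backward_orders(deps)
assert len(orders) == 2  # either peer may go first; dA is always last
```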
Seems that we were able to resolve the core issue on our side by modifying how we do FSDP wrapping for particular modules. In light of that, adding multiple unshard streams might be overkill. However, we still do see some incorrect waiting between computation and the unshard stream when using higher prefetch factors (>2), so something like this might still be useful. Our team will likely hold off on this until the new FSDP rewrite, though. I can close out this issue unless people want to continue the discussion. Thanks, y'all! :)
@snarayan21 Sounds good. Let us close this for now then. |
This PR gives each `_FSDPState` with a separate process group its own unshard stream. This addresses an issue we have been seeing with overlapping communication and computation when doing 2D parallelism with FSDP.

Essentially, the whole model is sharded across all ranks, and communication is done over the default process group, which contains all ranks. We 2D-parallelize the model by sharding a subset of the parameters over distinct process groups. For example, with 16 GPUs, most of the model is sharded across all 16 GPUs, but some parameters are sharded across groups of 4 GPUs and replicated 4 times. So GPU 0 will be part of the default process group, but also of another process group that includes ranks 0, 1, 2, 3. For these parameters, all-gathers and reduce-scatters are done with this new process group. We FSDP-wrap the parent class with the default process group, but also wrap these parameters with the new process group.
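The rank layout described above can be sketched as follows. The helper name is ours, not an FSDP API; in practice each of these rank lists would be passed to `torch.distributed.new_group`, which every rank must call for every group.

```python
def replica_groups(world_size, shard_size):
    """Rank groups for the 2D layout described above.

    With ``world_size=16`` and ``shard_size=4``, the "small"
    parameters are sharded within each group of 4 consecutive ranks
    and replicated across the 4 groups.
    """
    assert world_size % shard_size == 0
    return [list(range(start, start + shard_size))
            for start in range(0, world_size, shard_size)]

groups = replica_groups(16, 4)
# GPU 0 belongs to the default group (all 16 ranks) *and* to the
# custom group [0, 1, 2, 3]:
assert groups[0] == [0, 1, 2, 3]
```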
The issue arises because FSDP normally sets the same unshard stream for all `_FSDPState` objects, meaning that the default computation stream will wait on any communication kernels that are currently in the unshard stream. As far as I understand, with just one process group this is necessary, because we need to wait for parameters to be all-gathered before proceeding with computation. However, with more than one process group, we find that computation will wait for parameters to be unsharded even when those parameters are not needed for that computation. In both the forward and backward passes, this results in the computation stream excessively waiting on all-gathers that it does not need to wait on.

This PR has a solution to the above issue. For each process group attached to any `_FSDPState`, we assign a new unshard stream. With this, the computation stream only waits on all-gather ops that are actually necessary for that computation to proceed. With 2D parallelism, computation that requires unsharding parameters from all ranks (using the default process group) will only wait on all-gathers that use the default process group, while computation that requires unsharding parameters from ranks in the custom process group will only wait on all-gathers from the custom process group. This eliminates bubbles where computation and communication could have overlapped and, according to our tests, improves training efficiency by up to 25% at scale.

There are definitely many ways to address this, so we wanted to get input on how best to solve it, and on how it can be addressed in the upcoming FSDP rewrite and in support for 2D parallelism. Given the discussions here and here, there may be some risks in the current approach of using multiple NCCL communicators concurrently, since each process group may have its own NCCL communicator (unsure about this though). FWIW, in our testing we have not seen any hangs, and training has been identical before and after the change. It may also be possible to FSDP-wrap our model differently to make sure that communications are queued at the right time.
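The per-process-group stream assignment could be sketched like this. This is a stand-in class of our own, not the actual FSDP implementation; in real code the injected factory would be `torch.cuda.Stream`, while a plain object suffices to show the bookkeeping.

```python
class UnshardStreamPool:
    """One unshard stream per process group (the core idea of this PR).

    ``stream_factory`` would be ``torch.cuda.Stream`` in real code; a
    stub is injected here so the sketch stays CPU-runnable. States
    sharing a process group reuse one stream, so their all-gathers
    still serialize with each other, while states on different process
    groups no longer wait on one another's communication.
    """

    def __init__(self, stream_factory):
        self._factory = stream_factory
        self._streams = {}

    def get(self, process_group):
        # Lazily create one stream per distinct process group.
        if process_group not in self._streams:
            self._streams[process_group] = self._factory()
        return self._streams[process_group]
```

Usage: every `_FSDPState` would call `pool.get(state.process_group)` instead of sharing a single module-wide unshard stream.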
Looking forward to discussing this further and helping with a fix!
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen