
Separate unshard stream for each process group with FSDP #116611

Closed · wants to merge 8 commits

Conversation

@snarayan21 snarayan21 (Contributor) commented Jan 2, 2024

This PR gives each _FSDPState with a separate process group its own unshard stream. This addresses an issue we have been seeing with overlapping communication and computation when doing 2D parallelism with FSDP.

Essentially, the whole model is sharded across all ranks, and communication is done over the default process group which contains all ranks. We 2D parallelize the model by sharding a subset of the parameters over distinct process groups. For example, with 16 GPUs, most of the model is sharded across all 16 GPUs, but some parameters are sharded across groups of 4 GPUs, and replicated 4 times. So GPU 0 will be part of the default process group, but also another process group that includes ranks 0, 1, 2, 3. For these parameters, all-gathers and reduce-scatters are done with this new process group. We FSDP wrap the parent module with the default process group, but also wrap these parameters with the new process group.
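For concreteness, a minimal sketch of this kind of wrapping (assuming 16 GPUs launched with torchrun; the `Block`/`expert`/`dense` names are made up for illustration, not our actual model):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()  # assumes world_size == 16
torch.cuda.set_device(rank % torch.cuda.device_count())

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.expert = nn.Linear(1024, 1024)  # to be sharded over a 4-rank group, replicated 4x
        self.dense = nn.Linear(1024, 1024)   # to be sharded over all 16 ranks

    def forward(self, x):
        return self.dense(self.expert(x))

# Ranks {0..3}, {4..7}, {8..11}, {12..15} each form a 4-way shard group.
# new_group() is collective, so every rank must create every group.
sub_groups = [dist.new_group(list(range(s, s + 4))) for s in range(0, world_size, 4)]
my_sub_group = sub_groups[rank // 4]

model = Block().cuda()
model.expert = FSDP(model.expert, process_group=my_sub_group)  # custom PG for these params
model = FSDP(model)  # parent module wrapped with the default (16-rank) process group
```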

The issue arises because FSDP normally sets the same unshard stream for all _FSDPState objects, meaning that the default computation stream will wait on any communication kernels that are currently in the unshard stream. As far as I understand, with just one process group, this is necessary because we need to wait for parameters to be all-gathered before proceeding with computation. However, with >1 process group, we find that computation will wait for parameters to be unsharded even when those parameters are not needed for that computation. In both the forwards and backwards passes, this results in the computation stream excessively waiting for all-gathers that do not need to be waited on.

This PR has a solution to the above issue. For each process group attached to any _FSDPState, we assign a new unshard stream. With this, the computation stream only waits on all-gather ops that are actually necessary for that computation to proceed. With 2D parallelism, computation that requires unsharding parameters from all ranks (using the default process group) will only wait on all-gathers that use the default process group, while computation that requires unsharding parameters from ranks in the custom process group will only wait on all-gathers from the custom process group. This eliminates bubbles where computation sits idle waiting on communication it does not depend on, and according to our tests, improves training efficiency by up to 25% at scale.
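As a rough sketch of the idea (the helper and module-level cache below are illustrative placeholders, not the exact diff in this PR), the unshard stream is looked up per process group rather than shared across all _FSDPState objects:

```python
import torch

# Hypothetical module-level cache: one unshard stream per process group.
_unshard_streams = {}

def _get_unshard_stream(process_group):
    """Return a dedicated high-priority CUDA stream for this group's all-gathers."""
    stream = _unshard_streams.get(process_group)
    if stream is None:
        stream = torch.cuda.Stream(priority=-1)  # lower number = higher CUDA priority
        _unshard_streams[process_group] = stream
    return stream

# The computation stream then only waits on the stream whose all-gathers it
# actually needs, for example:
#   torch.cuda.current_stream().wait_stream(_get_unshard_stream(state.process_group))
# so all-gathers issued on another process group's stream no longer block it.
```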

There are definitely many ways to address this so we wanted to get input on how to best solve this, and how this can be addressed in the upcoming FSDP rewrite and in support for 2D parallelism. Given the discussions here and here there may be some risks involved in the current approach with using multiple NCCL communicators concurrently, since each process group may have its own NCCL communicator (unsure about this though). FWIW in our testing we have not seen any hangs and training has been identical before and after the change. It also may be possible to FSDP wrap our model differently to make sure that communications are queued at the right time.

Looking forward to discussing this further and helping with a fix!

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen

@pytorch-bot pytorch-bot bot added the release notes: onnx label Jan 2, 2024

pytorch-bot bot commented Jan 2, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116611

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit c1f67a6 with merge base 95a86ed:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.


linux-foundation-easycla bot commented Jan 2, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@github-actions github-actions bot added the oncall: distributed, module: mkldnn, and ciflow/inductor labels Jan 2, 2024
@Skylion007 Skylion007 (Collaborator) left a comment


Please fix your submodules and make sure to run lintrunner on this PR.

@awgu awgu self-assigned this Jan 2, 2024
@lezcano lezcano removed their request for review January 3, 2024 01:40
@cpuhrsch cpuhrsch added the triaged label Jan 3, 2024
@Skylion007 Skylion007 self-requested a review January 3, 2024 17:05
@awgu awgu removed the module: mkldnn and release notes: onnx labels Jan 3, 2024
awgu previously approved these changes Jan 3, 2024
@awgu awgu (Contributor) left a comment


This sounds good to me!

I recall that the concurrent NCCL communicator issue was resolved in a recent NCCL update? cc: @kwen2501

Review threads on torch/distributed/fsdp/_runtime_utils.py (outdated, resolved)
@awgu awgu dismissed their stale review January 3, 2024 19:24

Real failures in unit tests

@snarayan21 snarayan21 (Contributor, Author) commented Jan 8, 2024

@awgu The test linked here was failing because a Stream object didn't have the priority attribute, so I added the try-except: https://github.com/pytorch/pytorch/actions/runs/7389134737/job/20102917057?pr=116611

@awgu awgu (Contributor) commented Jan 8, 2024

@snarayan21 Thanks for the pointer! This looks like it is failing because the dummy CPU Stream class does not have the priority attribute. We should fix that instead of adding the try/except in FSDP.

Can we change

```python
class Stream:
    """
    N.B. This class only exists to facilitate device-agnostic code
    """
    def __init__(self, priority: int = -1):
        pass
```

to set self.priority = priority?
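
A sketch of the suggested one-line change (illustrative, not a final patch):

```python
class Stream:
    """
    N.B. This class only exists to facilitate device-agnostic code
    """

    def __init__(self, priority: int = -1):
        # Store the requested priority so device-agnostic callers can read
        # `stream.priority`, matching the real CUDA Stream attribute.
        self.priority = priority
```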

@github-actions github-actions bot added the module: cpu and module: mkldnn labels Jan 8, 2024
@snarayan21 snarayan21 (Contributor, Author)

Fixed the dummy Stream class priority attribute issue in this commit.

@snarayan21 snarayan21 (Contributor, Author) commented Jan 8, 2024

> I recall that the concurrent NCCL communicator issue was resolved in a recent NCCL update?

@awgu We are currently seeing some issues with deadlocks when using multiple process groups. Do you have more information on what change to NCCL resolved this, and maybe a NCCL version? That would be super helpful, thanks!

@snarayan21 snarayan21 requested a review from awgu January 8, 2024 21:24
awgu previously approved these changes Jan 8, 2024
@awgu awgu (Contributor) left a comment


LGTM!

Regarding the concurrent NCCL PGs, I am not too sure. @kwen2501 do you have any insights?

@awgu awgu added the ciflow/trunk label Jan 8, 2024
@awgu awgu (Contributor) commented Jan 9, 2024

@snarayan21 There is still a unit test failing 🤔

python test/distributed/fsdp/test_fsdp_core.py -k test_mixture_of_experts_offload_false_none

I might need to find some time to patch your PR and look into this one. I do not have a sense of why it is failing off the top of my head. We may want to get some clarity on the deadlocks too before landing.

@awgu awgu dismissed their stale review January 9, 2024 02:34

Need to debug failing unit test

@snarayan21 snarayan21 (Contributor, Author)

@awgu one datapoint for debugging this unit test failure: seems to work for the first iteration but not the second.

@wconstab wconstab (Contributor) commented Jan 9, 2024

> Given the discussions NVIDIA/nccl#195 and NVIDIA/nccl#195 there may be some risks involved in the current approach with using multiple NCCL communicators concurrently, since each process group may have its own NCCL communicator (unsure about this though)

It should be OK to use multiple NCCL communicators for unshard operations (allgathers), provided you ensure that the allgathers go in the same order on every rank. If you don't ensure this, you could observe an allgather timeout.

@snarayan21 snarayan21 (Contributor, Author)

Running into an issue with non-deterministic loss curves with this patch. I am working on a fix and will update the PR when that happens.

@snarayan21 snarayan21 (Contributor, Author) commented Jan 17, 2024

@awgu I'm still looking into how to address this problem of the computation stream waiting on all-gather operations that the computation doesn't actually require. Is there a way to address this without using multiple unshard streams? From what I can tell, there's no way to make wait_stream conditional, as in, only have computation of type A wait on all-gathers of type A, and only have computation of type B wait on all-gathers of type B, without using separate streams. I thought about using events, but it seems that the wait_stream call uses record_event under the hood as well. My understanding of this is incomplete, so is there possibly a way to wait on particular events to solve this problem? How is this being addressed differently in per-parameter sharding (link)?

@awgu awgu (Contributor) commented Jan 17, 2024

@snarayan21 Using CUDA events with a single unshard stream may not address this issue, because waiting on an event recorded in a stream also waits for any work enqueued in that stream before the event was recorded, which means that the CPU thread's issue order matters.
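
A standalone sketch of that semantics, using elementwise kernels as stand-ins for all-gathers (not FSDP code): even though the computation only needs the second kernel's output, waiting on an event recorded after both forces it to wait for both.

```python
import torch

compute_stream = torch.cuda.current_stream()
unshard_stream = torch.cuda.Stream()

a = torch.randn(1 << 20, device="cuda")
b = torch.randn(1 << 20, device="cuda")

# Let the side stream see a and b, which were produced on the compute stream.
unshard_stream.wait_stream(compute_stream)

with torch.cuda.stream(unshard_stream):
    a.mul_(2)                      # stand-in for all-gather #1 (not needed yet)
    b.mul_(2)                      # stand-in for all-gather #2 (needed next)
    event_b = torch.cuda.Event()
    event_b.record(unshard_stream)

# Waiting on event_b orders compute_stream after everything enqueued on
# unshard_stream before the record call, including the stand-in for #1.
compute_stream.wait_event(event_b)
out = b + 1  # computation that only needed "all-gather #2" still waits on #1
```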

Were you able to understand better why the fix in the PR is insufficient?

@awgu awgu (Contributor) commented Jan 17, 2024

For per-parameter sharding, we are using CUDA events whenever possible to avoid the heavier-weight CUDA stream synchronization. However, as mentioned in the previous comment, waiting on an event will still wait for all pending work issued before that event was recorded.

One question I have is, for your use case, does the order in which all-gathers are issued by the CPU match the order in which the all-gathers are expected to be waited on by the computation stream?

@snarayan21 snarayan21 (Contributor, Author) commented Jan 17, 2024

@awgu Yeah, the more I think about it and investigate, it looks like the all-gather issue order is partially incorrect (in the backward pass, due to the order of the _post_forwards calls being wrong). But even after correcting the all-gather issue order in the backward pass, this issue persists. I'm going to look into redoing our FSDP wrapping and see if that addresses the issue we are seeing, hopefully improving efficiency and giving deterministic loss curves.

As long as the order that kernels are queued is correct from the CPU side, and there are no unintended dependencies between communication and computation, then the stream waits should be correct, afaik. Ideally we should never see unintended waiting if the issue order is correct, right?

@wconstab wconstab (Contributor)

> it looks like the all-gather issue order is partially incorrect (in the backward pass, due to the order of the _post_forwards calls being wrong).

We don't have much control over the order in which autograd's engine traverses the DAG of backward ops. If FSDP encounters a case like

b = foo(a); c = bar(a)

where foo and bar are both FSDP-wrapped modules that depend on a common input, I wonder how we deal with ensuring the backward ordering is consistent between ranks. It's valid for autograd to pick either ordering (dFoo, dBar or dBar, dFoo) since they are peers in the graph. This potentially leads to two subproblems:

  1. You actually care that dBar goes first (because that's more optimal from a comms-overlap perspective).
  2. We need the same order on every rank, even if we don't care about (1) - we can't let rank0 and rank1 pick different orders.

Do we have a solution for (2) today in FSDP or are we just not hitting this case in real models? @awgu
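
For reference, a hypothetical minimal setup that hits this case (assumes a process group is already initialized; the module names are made up): foo and bar are sibling FSDP-wrapped submodules consuming the same input, so autograd may legally run either backward first.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Siblings(nn.Module):
    def __init__(self):
        super().__init__()
        self.foo = nn.Linear(16, 16)
        self.bar = nn.Linear(16, 16)

    def forward(self, a):
        b = self.foo(a)  # dFoo and dBar are peers in the autograd graph,
        c = self.bar(a)  # so their backward order is not fixed by the graph itself
        return b + c

model = Siblings().cuda()
model.foo = FSDP(model.foo)
model.bar = FSDP(model.bar)
model = FSDP(model)
```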

@snarayan21 snarayan21 (Contributor, Author) commented Jan 18, 2024

It seems that we were able to resolve the core issue on our side by modifying how we do FSDP wrapping for particular modules. In light of that, creating multiple unshard streams might be overkill. However, we still see some incorrect waiting between computation and the unshard stream when using higher prefetch factors (>2), so something like this might still be useful. Our team will likely hold off on this until the new FSDP rewrite, though. I can close out this PR unless people want to continue the discussion. Thanks, y'all! :)

@awgu awgu (Contributor) commented Jan 18, 2024

@snarayan21 Sounds good. Let us close this for now then.

Labels
ciflow/inductor, ciflow/trunk, module: cpu, module: fsdp, module: mkldnn, oncall: distributed, open source, release notes: distributed (fsdp), triaged
6 participants