FSDP very slow on multi-node training #102434
Comments
Maybe your workload is communication bound, and when you go from single-node to multi-node, FSDP's communication (all-gather / reduce-scatter) is heavily exposed on the critical path? I would recommend you collect profiler traces for both the single-node and multi-node cases. By the way, regarding:
If I am understanding correctly, the second time is actually the forward and backward time, and the third time is the optimizer step? (It would be strange if forward takes much longer than backward.)
Yes, the second time is actually the forward and backward time. Can I ask what do
I further tested the forward time and backward time, results are as follows:
single-node:
According to my observation, it seems that both inference and backward are slowed down in multi-node training.
FSDP uses collective communications: all-gather for parameters and reduce-scatter for gradient reduction. The forward pass only uses all-gather, whereas the backward pass uses both all-gather and reduce-scatter (meaning twice as much communication as forward). If communication is exposed on the critical path, then it is not overlapped with computation. For multi-node, the communication may take longer due to the slower inter-node bandwidth, which may make communication more easily exposed. The times you are getting look unintuitive to me. As I mentioned before, I would recommend getting a profiler trace; then it will be clear what is going on. In addition, I would not recommend using
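A minimal profiling sketch along these lines, assuming a training loop with placeholder names (`model`, `loader`, and `run_one_step` are not from this issue):

```python
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
from torch.profiler import schedule as profiler_schedule

# Sketch only: wrap a few training steps in the PyTorch profiler so that
# exposed all-gather / reduce-scatter time shows up in the trace.
# `model`, `loader`, and `run_one_step` are placeholders for your own code.
prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=profiler_schedule(wait=1, warmup=2, active=3),
    on_trace_ready=tensorboard_trace_handler("./fsdp_trace"),
)
with prof:
    for step, batch in enumerate(loader):
        run_one_step(model, batch)  # forward + backward + optimizer step
        prof.step()                 # advance the profiler schedule
        if step >= 6:
            break
# Open the resulting trace (e.g., in TensorBoard or chrome://tracing) and check
# whether NCCL all-gather / reduce-scatter kernels overlap with compute kernels.
```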
OK, thank you. I will try to calculate the time and update soon. Could I please ask what being "exposed on the critical path" means?
Critical path refers to the ops that actually affect your end-to-end time. An op is not on the critical path if it is fully overlapped with other ops, which can happen since communication and computation can use separate GPU resources. For example, if FSDP can all-gather the 'next' layer's parameters before finishing the 'current' layer's forward computation, then that all-gather is not on the critical path, because by the time we run the 'next' layer's forward computation, we already have the parameters materialized. On the other hand, if the 'next' all-gather takes longer than the 'current' computation, then the part that is not overlapped is exposed and delays the 'next' computation. I highly recommend looking at profiler traces to make these ideas concrete.
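Relatedly, how much of the next all-gather overlaps with the current computation depends in part on FSDP's prefetching settings. A sketch of the relevant knobs, assuming `model` and `my_auto_wrap_policy` are defined elsewhere (the defaults may already be reasonable):

```python
import torch
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP

# Sketch only: prefetching issues the next all-gather earlier so it can overlap
# with the current computation. `model` and `my_auto_wrap_policy` are placeholders.
model = FSDP(
    model,
    auto_wrap_policy=my_auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next params during backward
    forward_prefetch=True,                            # prefetch next params during forward
    device_id=torch.cuda.current_device(),
)
```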
I further tested the speed using the following code:
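A CUDA-event timing sketch of the kind described (not the exact code from this comment; `model` and `batch` are placeholders):

```python
import torch

# Sketch of CUDA-event timing for one forward/backward pass.
# Events are recorded on the current stream; synchronize before reading times.
start_fwd, end_fwd = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
start_bwd, end_bwd = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)

start_fwd.record()
loss = model(batch)  # placeholder forward + loss computation
end_fwd.record()

start_bwd.record()
loss.backward()
end_bwd.record()

torch.cuda.synchronize()
print(f"forward:  {start_fwd.elapsed_time(end_fwd):.1f} ms")
print(f"backward: {start_bwd.elapsed_time(end_bwd):.1f} ms")
```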
Results are as follows:
multi-node (16xA100):
It seems that in multi-node training, the CUDA event time is 10x slower.
This is my distributed init code, for possible bug checking @awgu:
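A typical multi-node init from Slurm environment variables looks roughly like this (a sketch, not necessarily the exact code from this comment):

```python
import os

import torch
import torch.distributed as dist

# Sketch: derive rank / world size from Slurm-provided environment variables
# and initialize the NCCL process group. MASTER_ADDR / MASTER_PORT must point
# to a node reachable from all workers (often exported in the sbatch script).
rank = int(os.environ["SLURM_PROCID"])
world_size = int(os.environ["SLURM_NTASKS"])
local_rank = int(os.environ["SLURM_LOCALID"])

torch.cuda.set_device(local_rank)
dist.init_process_group(
    backend="nccl",
    init_method="env://",  # reads MASTER_ADDR and MASTER_PORT from the environment
    rank=rank,
    world_size=world_size,
)
```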
Besides, I want to ask whether the
These are the results I gathered using
Then is the
I don't quite understand all these metrics, but I can see that:
Multi-node:
Single-node:
I do not see the matmuls taking more time (you can similarly check
Sorry, I read it wrong; the main time consumption comes from
DDP only uses all-reduce for communication, not all-gather/reduce-scatter. All-reduce is more optimized in practice. At the same time, how much larger is your FSDP model than your DDP model? This affects the communication volume.
The DDP model is much smaller: the FSDP model is a LLaMA-13B model plus some linear layers, while DDP only tunes the linear layers.
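For a rough sense of scale, a back-of-envelope of per-step collective volume, assuming 16-bit parameters and gradients, FSDP's default re-sharding after forward, and an assumed size for the trainable linear layers:

```python
# Rough back-of-envelope (not measured numbers from this issue).
llama_13b_params = 13e9      # full model sharded by FSDP
linear_only_params = 100e6   # assumed size of the trainable linear layers in the DDP setup
bytes_per_elem = 2           # fp16 / bf16

# FSDP (default re-sharding after forward): all-gather params in forward,
# all-gather again in backward, then reduce-scatter gradients in backward.
fsdp_bytes = 3 * llama_13b_params * bytes_per_elem

# DDP: one all-reduce over only the trainable gradients (~2x data moved per element).
ddp_bytes = 2 * linear_only_params * bytes_per_elem

print(f"FSDP per-step collective volume: ~{fsdp_bytes / 1e9:.0f} GB")
print(f"DDP  per-step collective volume: ~{ddp_bytes / 1e9:.1f} GB")
```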
I am going to mark this as closed because I see no evidence to suggest this is an issue with FSDP. From our discussion, it seems that you have a slow inter-node interconnect.
Finally, I set
Hi, JulioZhao97, do you have your full code for training with FSDP on multiple nodes? Can you share it with me? Appreciated.
The model wrapping is something like this, mainly adapted from
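A typical wrapping setup along those lines looks roughly like this (a sketch assuming a Hugging Face LLaMA model and bf16 mixed precision, not necessarily the exact code referenced above):

```python
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Sketch: wrap each decoder layer as its own FSDP unit so parameters are
# all-gathered layer by layer instead of all at once.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(
    model,  # placeholder: the LLaMA-13B model plus extra linear layers
    auto_wrap_policy=auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    ),
    device_id=torch.cuda.current_device(),
)
```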
Thanks! By the way, are there any changes in the main training file? I used FSDP to run successfully on a single node, but now I want to run on multiple nodes, and I don't know how to write the code.
As far as I can tell, multi-node training and single-node training are basically the same? If you are running on a Slurm cluster, the srun parameters need to be carefully set. Besides, what is your error in multi-node training? @GasolSun36
Here is my script: echo "NODELIST="${SLURM_NODELIST} srun train.py and I applied for 2 nodes with 4 cards each on the cluster, for a total of 8 cards. The error is
@JulioZhao97 Hi, the above problem is solved, but there is a new problem: ValueError: Using fsdp only works in distributed training. Can I see your settings in the main training file, please? My main reference is this link: https://gist.github.com/TengdaHan/1dd10d335c7ca6f13810fff41e809904
@JulioZhao97 echo "NODELIST="${SLURM_NODELIST} srun python train.py Did you mean
Same problem for me: multi-node training yields a linearly increased time. Can anyone help? @awgu
Thanks for your great discussion! I also hit the problem that multi-node training is much slower than single-node. Two-node training takes twice as long as single-node. My training doesn't use FSDP. Do you have any suggestions? Thank you very much! @JulioZhao97
I suggest you check timing using
Maybe try DeepSpeed; it works perfectly well on my 8-node machine. FSDP just didn't work.
🐛 Describe the bug
When I try to train a model using torch.distributed.FullyShardedDataParallel, I found that:
when training using single-node multi-GPU (1x8A100), the training speed is normal.
when training using multi-node multi-GPU (2x8A100 or 4x8A100), the training speed is very slow.
My FSDP code is as follows:
I print out the training speed; the results are as follows
(three lines: the first is the data loading time, the second is the model inference and loss calculation time, and the last is the backward() time):
First is the speed using 4x8A100, where the model inference is very slow.
Then the speed using 1x8A100, where the model inference is perfectly normal:
Could someone tell me why this is happening?
My test code:
Versions
My versions:
My driver version:
cc @zhaojuanmao @mrshenli @rohan-varma @awgu