About the reason for doing all-gather in the propagation #7123
Yeosu-expo started this conversation in General
Replies: 1 comment · 1 reply
- I am a student studying DNN training engines. I'm currently analyzing DeepSpeed, which is powered by ZeRO, and I've been reading the ZeRO papers (e.g., ZeRO, ZeRO-Offload, ZeRO++). I've been wondering why ZeRO performs an all-gather on all parameter partitions in both the forward and backward passes. My understanding is that because ZeRO partitions parameters in an N-way manner (where N denotes the degree of Data Parallelism), without regard to the computation, an all-gather is required before computation, unlike in Model Parallelism. So, my main question is: since ZeRO partitions parameters only according to Data Parallelism, does it need to all-gather all parameter partitions in both the forward and backward passes? And if ZeRO partitioned parameters the way Model Parallelism does (split column-wise or row-wise, with the computation in mind), would an all-gather no longer be necessary?
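To make the N-way partitioning concrete, here is a minimal single-process sketch (plain PyTorch; `world_size`, `shards`, and `gathered` are illustrative names, not DeepSpeed APIs, and the N ranks are simulated with a Python list). ZeRO-3 flattens each parameter and gives every data-parallel rank a 1/N slice, so no single rank can run the layer's matmul until the slices are gathered back:

```python
import torch

world_size = 4                 # degree of data parallelism (N)
torch.manual_seed(0)

W = torch.randn(8, 8)          # full weight of a linear layer, Y = X W
x = torch.randn(2, 8)          # a mini-batch of activations

# ZeRO-style partitioning: flatten W and give each "rank" a 1/N slice.
# A slice is just a run of numbers with no row/column structure, so
# rank i cannot compute x @ W from shards[i] alone.
flat = W.flatten()
shards = list(flat.chunk(world_size))

# The all-gather reconstructs the full flat parameter on every rank.
gathered = torch.cat(shards).view_as(W)
assert torch.equal(gathered, W)

y = x @ gathered               # the forward matmul needs the whole W
print(y.shape)                 # torch.Size([2, 8])
```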
- Consider a simple linear layer, $Y = XW$: in the backward pass, to calculate the gradient with respect to the input (say $X$), the weight matrix $W$ is needed, since $dX = dY \cdot W^T$. Hence one needs to all-gather $W$ in the backward pass as well.
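The formula above is easy to verify with autograd. Here is a minimal sketch (plain PyTorch, single process; `dy` stands in for the upstream gradient) showing that the input gradient of $Y = XW$ consumes the full weight matrix, which is why a ZeRO-3 rank holding only a 1/N slice must all-gather $W$ again before the backward matmul:

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 8)
x = torch.randn(2, 8, requires_grad=True)

y = x @ W                      # forward: Y = X W
dy = torch.randn_like(y)       # stand-in for the upstream gradient dY
y.backward(dy)                 # autograd computes dX for us

# The manual gradient w.r.t. the input uses the FULL weight matrix:
# dX = dY @ W^T
manual_dx = dy @ W.T
assert torch.allclose(x.grad, manual_dx)
```

A column-wise or row-wise (Model Parallel) split would let each rank keep its slice and reduce partial activations or gradients instead, but ZeRO partitions along a flat 1/N boundary precisely so the computation stays identical to plain data parallelism; the forward and backward all-gathers are the price for that.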