[Model Averaging] Create a post-localSGD communication hook #61206
Conversation
Create a communication hook to run post-local SGD. This will be combined with a model averager component to better support local SGD.

In contrast to the previous approach, which ran local gradient averaging + global model averaging at each step for the first K steps, we now plan to run only global gradient averaging at each step for the first K steps, just like normal DDP. This gives us two advantages:
1) For some optimizers, model averaging can cause a discrepancy in optimizer states. If we still do global gradient averaging for the first K steps, we can defer such a discrepancy until local SGD actually starts.
2) Gradient averaging in the first K steps only runs one allreduce that overlaps with the backward pass, so it should also be more efficient.

Differential Revision: [D29523292](https://our.internmc.facebook.com/intern/diff/D29523292/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D29523292/)!

[ghstack-poisoned]
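As a rough usage sketch (not part of this PR's diff): the hook would be registered on a DDP model roughly as below. The device setup and subgroup choice are illustrative, and the `PeriodicModelAverager` lines are an assumption about the planned companion model-averager component, not an API introduced here.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes the process group is already initialized, e.g. via
# dist.init_process_group("nccl"), with one GPU per rank.
rank = dist.get_rank()
model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])

# For simplicity the subgroup here spans the whole world (matching the unit
# test in this PR); in practice it would usually be an intra-node group.
subgroup = dist.new_group(ranks=list(range(dist.get_world_size())))

# Behave like vanilla DDP (one overlapped global allreduce per backward pass)
# for the first 100 iterations, then average gradients only within `subgroup`.
state = post_localSGD.PostLocalSGDState(
    process_group=None,  # None presumably falls back to the global group, as in other comm hooks
    subgroup=subgroup,
    start_localSGD_iter=100,
)
model.register_comm_hook(state, post_localSGD.post_localSGD_hook)

# Once local SGD starts, the planned model-averager component would
# periodically reconcile parameters across the global group, roughly:
#   averager = PeriodicModelAverager(period=4, warmup_steps=100)  # assumed name/signature
#   ...; optimizer.step(); averager.average_parameters(model.parameters())
```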
💊 CI failures summary and remediations
As of commit 076fc6c (more details on the Dr. CI page and at hud.pytorch.org/pr/61206):
🕵️ 2 new failures recognized by patterns
The following CI failures do not appear to be due to upstream breakages:
pytorch_linux_xenial_py3_clang7_onnx_ort_test1 (1/2), step: "Run tests"
torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py
# Since we start local SGD later than the total number of 100 iterations,
# no local SGD actually is executed, and we don't even need to provide a subgroup for this case.
state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=1000)
self._test_ddp_hook_parity(state=state, hook=post_localSGD.post_localSGD_hook)
Should we also test a scenario where the global process group and the subgroup actually differ? Or is that going to be tested in a follow-up PR?
Let me do it in a follow-up PR. To make the subgroup actually differ from the global group, I probably need to refactor _test_ddp_hook_parity first.
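For reference, a hypothetical shape for such a follow-up test (group sizes and names are assumptions, and `_test_ddp_hook_parity` would first need the refactoring mentioned above) might look like:

```python
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

# Split an 8-rank world into two 4-rank subgroups (sizes are illustrative).
world_size = dist.get_world_size()
rank = dist.get_rank()

# Every rank must call new_group() for every subgroup, with identical ranks lists.
subgroups = [dist.new_group(ranks=list(range(start, start + 4)))
             for start in range(0, world_size, 4)]
my_subgroup = subgroups[rank // 4]

# Start local SGD early enough that subgroup-only gradient averaging is exercised.
state = post_localSGD.PostLocalSGDState(
    process_group=None, subgroup=my_subgroup, start_localSGD_iter=10
)
# self._test_ddp_hook_parity(state=state, hook=post_localSGD.post_localSGD_hook)
```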
Pull Request resolved: #61206. Proposal: #59699. ghstack-source-id: 133203989. Differential Revision: [D29523292](https://our.internmc.facebook.com/intern/diff/D29523292/).
This pull request has been merged in 0f6876d.
BACKEND != "nccl" and BACKEND != "gloo", | ||
"MPI backend does not support DDP communication hook on CUDA devices", | ||
) | ||
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \ |
Do we know why we need this skip decorator? All other tests seem to be able to run without it. Is there some detail, such as the use of different NCCL subgroups, that causes issues when there is no spawn start method?
Sorry that I skipped this comment before. I don't remember the reason. Here the subgroup is equivalent to the global one, so it is really strange.
Stack from ghstack:
Create a communication hook to run post-local SGD. This will be combined with a model averager component to better support local SGD.
In contrast to the previous approach, which ran local gradient averaging + global model averaging at each step for the first K steps, we now plan to run only global gradient averaging at each step for the first K steps, just like normal DDP. This gives us two advantages:
1) For some optimizers, model averaging can cause a discrepancy in optimizer states. If we still do global gradient averaging for the first K steps, we can defer such a discrepancy until local SGD actually starts.
2) Gradient averaging in the first K steps only runs one allreduce that overlaps with the backward pass, so it should also be more efficient.
Proposal: #59699
Differential Revision: D29523292
NOTE FOR REVIEWERS: This PR has internal Facebook-specific changes or comments; please review them on Phabricator!
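A simplified sketch of the hook's two-phase control flow described above, not the exact code in this PR: it assumes the gradient bucket exposes `buffer()` and, for brevity, advances the iteration counter once per bucket rather than once per backward pass.

```python
import torch
import torch.distributed as dist


def post_localSGD_hook_sketch(state, bucket) -> torch.futures.Future:
    global_group = (
        state.process_group if state.process_group is not None else dist.group.WORLD
    )
    tensor = bucket.buffer()  # flattened gradients in this bucket (name is an assumption)

    if state.iter < state.start_localSGD_iter:
        # Phase 1: just like normal DDP -- one global allreduce that overlaps
        # with the backward pass.
        group_to_use = global_group
    else:
        # Phase 2: local SGD -- average gradients only within the subgroup;
        # the model averager later reconciles parameters across the global group.
        group_to_use = state.subgroup
    state.iter += 1  # the real hook would advance this once per iteration

    # Pre-divide by the group size so the allreduce produces an average.
    tensor.div_(dist.get_world_size(group_to_use))
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()
    return fut.then(lambda f: f.value()[0])
```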