
[Model Averaging] Create a post-localSGD communication hook #61206

Closed
wants to merge 4 commits

Conversation


@wayi1 (Contributor) commented Jul 3, 2021

Stack from ghstack:

Create a communication hook to run post-local SGD. This will be combined with the model averager component to better support local SGD.

In contrast to the previous approach, which ran local gradient averaging plus global model averaging at every step for the first K steps, we now plan to run only global gradient averaging at every step for the first K steps, just like normal DDP. This gives us two advantages:

  1. For some optimizers, model averaging can cause a discrepancy in optimizer states. If we still do global gradient averaging for the first K steps, we defer such discrepancy until local SGD actually starts.
  2. Gradient averaging during the first K steps runs only one allreduce that overlaps with the backward pass, so it should also be more efficient.

Proposal: #59699

Differential Revision: D29523292

NOTE FOR REVIEWERS: This PR has internal Facebook-specific changes or comments; please review them on Phabricator!
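To make the intended usage concrete, here is a minimal, hedged sketch (not part of this PR's diff). The hook and `PostLocalSGDState` are the APIs added here; `PeriodicModelAverager` stands in for the model averager component referenced above, with a module path and signature assumed from the companion proposal; the model, data, device mapping, and hyperparameters are placeholders.

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD
from torch.distributed.algorithms.model_averaging.averagers import PeriodicModelAverager
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group(...) has already been called, one GPU per process,
# and that the rank maps to the local CUDA device. The model and data are placeholders.
rank = dist.get_rank()
model = DDP(nn.Linear(10, 10).to(rank), device_ids=[rank])

# For the first `start_localSGD_iter` iterations the hook behaves like vanilla DDP:
# one global allreduce of gradients that overlaps with the backward pass. After that,
# gradients are only all-reduced within `subgroup`. Here the subgroup simply spans all
# ranks for illustration; in practice it would typically be a per-machine group.
subgroup = dist.new_group(ranks=list(range(dist.get_world_size())))
state = post_localSGD.PostLocalSGDState(
    process_group=None,        # None -> the default (global) process group
    subgroup=subgroup,
    start_localSGD_iter=100,
)
model.register_comm_hook(state, post_localSGD.post_localSGD_hook)

# Once local SGD starts, the averager periodically averages parameters across all
# workers, complementing the subgroup-only gradient averaging.
averager = PeriodicModelAverager(period=4, warmup_steps=100)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for step in range(200):
    optimizer.zero_grad()
    loss = model(torch.randn(20, 10).to(rank)).sum()
    loss.backward()
    optimizer.step()
    averager.average_parameters(model.parameters())
```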


facebook-github-bot commented Jul 3, 2021

💊 CI failures summary and remediations

As of commit 076fc6c (more details on the Dr. CI page and at hud.pytorch.org/pr/61206):


  • 2/2 failures introduced in this PR

🕵️ 2 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test1 (1/2)

Step: "Run tests"

Jul 10 21:54:50 AssertionError:
Jul 10 21:54:50   File "/opt/conda/lib/python3.6/site-packages/hypothesis/core.py", line 517, in test
Jul 10 21:54:50     result = self.test(*args, **kwargs)
Jul 10 21:54:50   File "/opt/conda/lib/python3.6/site-packages/caffe2/python/operator_test/adam_test.py", line 284, in test_smart_decay_sparse_adam
Jul 10 21:54:50     input_device_options=input_device_options)
Jul 10 21:54:50   File "/opt/conda/lib/python3.6/site-packages/caffe2/python/hypothesis_test_util.py", line 669, in assertReferenceChecks
Jul 10 21:54:50     output_blob_name,
Jul 10 21:54:50   File "/opt/conda/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 1533, in assert_allclose
Jul 10 21:54:50     verbose=verbose, header=header, equal_nan=equal_nan)
Jul 10 21:54:50   File "/opt/conda/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
Jul 10 21:54:50     raise AssertionError(msg)
Jul 10 21:54:50 AssertionError: 
Jul 10 21:54:50 Not equal to tolerance rtol=0.0001, atol=0.0001
Jul 10 21:54:50 Output param is not matching the reference
Jul 10 21:54:50 Mismatched elements: 1 / 1 (100%)
Jul 10 21:54:50 Max absolute difference: 0.0001
Jul 10 21:54:50 Max relative difference: 0.066463
Jul 10 21:54:50  x: array([0.001608], dtype=float32)
Jul 10 21:54:50  y: array([0.001508], dtype=float32)
Jul 10 21:54:50 
Jul 10 21:54:50 Trying example: test_smart_decay_sparse_adam(self=<caffe2.python.operator_test.adam_test.TestAdam testMethod=test_smart_decay_sparse_adam>, inputs=[array([0.], dtype=float32),
Jul 10 21:54:50  array([0.019774], dtype=float32),

See CircleCI build pytorch_linux_xenial_py3_clang7_onnx_ort_test2 (2/2)

Step: "Run tests"

Jul 10 21:44:06 — same AssertionError in test_smart_decay_sparse_adam (identical traceback and mismatch values as in the test1 shard above).

Preview docs built from this PR

This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the (internal) Dr. CI Users group.

@wayi1 requested a review from rohan-varma July 7, 2021 05:05
Comment on lines +3909 to +3912
# Since local SGD is set to start at iteration 1000, later than the total of 100 iterations,
# no local SGD is actually executed, and we don't even need to provide a subgroup for this case.
state = post_localSGD.PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=1000)
self._test_ddp_hook_parity(state=state, hook=post_localSGD.post_localSGD_hook)
Contributor

Should we also test a scenario where we have different global and subprocess groups? Or is that going to be tested in a follow up PR?

Contributor Author

Let me do it in a follow-up PR. To make the subgroup actually different from the global group, I probably need to refactor `_test_ddp_hook_parity` first.
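For reference, a hedged sketch of how a subgroup distinct from the global group could be constructed for such a test; the half/half split of the world and the `start_localSGD_iter` value are illustrative assumptions, and the `_test_ddp_hook_parity` refactor itself is not shown.

```python
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

# Assumes the default process group is already initialized.
world_size = dist.get_world_size()
rank = dist.get_rank()
half = world_size // 2

# new_group() must be called by every rank for every group, in the same order,
# even for groups the rank does not belong to.
first_half = dist.new_group(ranks=list(range(half)))
second_half = dist.new_group(ranks=list(range(half, world_size)))
my_subgroup = first_half if rank < half else second_half

# With start_localSGD_iter small enough to be reached, gradients are all-reduced only
# within `my_subgroup` after that iteration, so the parity check would need to compare
# against a reference that averages within the same subgroup.
state = post_localSGD.PostLocalSGDState(
    process_group=None,
    subgroup=my_subgroup,
    start_localSGD_iter=10,
)
```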

wayi1 pushed a commit that referenced this pull request Jul 8, 2021
Pull Request resolved: #61206

ghstack-source-id: 133203989

Differential Revision: [D29523292](https://our.internmc.facebook.com/intern/diff/D29523292/)
@facebook-github-bot

This pull request has been merged in 0f6876d.

BACKEND != "nccl" and BACKEND != "gloo",
"MPI backend does not support DDP communication hook on CUDA devices",
)
@unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
Member

Do we know why we need this skip decorator? All other tests seem to be able to run without it. Is there some detail, such as the use of different NCCL subgroups, that causes issues when there is no spawn start method?

Contributor Author

Sorry, I missed this comment before. I don't remember the reason. Here the subgroup is equivalent to the global one, so it is indeed strange.

@facebook-github-bot deleted the gh/SciPioneer/152/head branch July 14, 2021 14:16
Labels: cla signed, Merged, oncall: distributed