[Gradient Compression] Implement the original layerwise PowerSGD #49417

Closed · wants to merge 4 commits

Conversation

@wayi1 (Contributor) commented Dec 15, 2020

Stack from ghstack:

The existing implementation applies PowerSGD to a batch of flattened tensors, which is a coarse-grained compression. That hook is now renamed to "batched_powerSGD_hook".

This PR implements the original algorithm from the paper, which applies PowerSGD to each per-parameter tensor, i.e., a layerwise, fine-grained compression. Although this original variant is slower, it is expected to achieve higher accuracy, especially when the shapes of the per-parameter tensors cannot be well aligned.

Also adds a test in distributed_test.py.

Original PR issue: Investigate Applying PowerSGD to Communication Hook for Gradient Compression #47202

Differential Revision: D25511543
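
For reference, a minimal sketch of the layerwise rank-1 compression described above (illustrative only, not the code in this PR): it runs a single power-iteration step on one per-parameter gradient and omits the error feedback, custom orthogonalization, and all-reduce steps of the actual hook.

```python
import torch

def powersgd_layerwise_step(grad, Q):
    # Treat the per-parameter gradient as a 2D matrix M of shape (n, m).
    M = grad.view(grad.shape[0], -1)
    # One power-iteration step on this single layer's gradient.
    P = M @ Q                    # (n, rank); all-reduced across workers in the real hook
    P, _ = torch.linalg.qr(P)    # orthogonalize P
    Q = M.t() @ P                # (m, rank); all-reduced across workers in the real hook
    # Decompress: a low-rank approximation of this layer's gradient.
    return (P @ Q.t()).view_as(grad), Q

grad = torch.randn(64, 128)         # one per-parameter gradient tensor
Q0 = torch.randn(grad.shape[1], 1)  # rank-1 init, seeded identically on all workers
approx, Q1 = powersgd_layerwise_step(grad, Q0)
```

In contrast, batched_powerSGD_hook views the whole flattened gradient bucket as a single matrix before applying the same approximation, which is cheaper but coarser.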

@facebook-github-bot (Contributor) commented Dec 15, 2020

💊 CI failures summary and remediations

As of commit 9e60718 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

wayi1 pushed a commit that referenced this pull request Dec 15, 2020
Pull Request resolved: #49417

ghstack-source-id: 118643443

Differential Revision: [D25511543](https://our.internmc.facebook.com/intern/diff/D25511543/)
@wayi1 changed the title from "[Gradient Compression] Implement the original PowerSGD" to "[Gradient Compression] Implement the original layerwise PowerSGD" on Dec 16, 2020
@rohan-varma (Member) left a comment:

Thanks for working on this! I left a few comments/questions inline.

q_idx += m * matrix_approximation_rank

# Initialize and then orthogonalize Qs.
with torch.random.fork_rng(devices=[]):
Member:

Would it be reasonable to dedupe other use cases of this forking in grad compression to a helper function?

@wayi1 (Contributor, Author):

I had the same feeling. However, the first implementation has a loop (`for q in qs:`) between setting the manual seed and filling in the random values, which makes this a bit tricky. Let me try to do it in a separate refactoring PR.
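
For context, a minimal sketch of the kind of helper being discussed, assuming the goal is to fill a list of tensors with random values that are identical on every worker (names are illustrative, not from this PR):

```python
import torch

def fill_with_shared_randomness(tensors, seed):
    # Fork the RNG so that seeding for compression does not perturb the
    # global random state used elsewhere in training (devices=[] forks
    # only the CPU generator).
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(seed)
        # Every worker uses the same seed, so all workers generate
        # identical random matrices without any extra communication.
        for t in tensors:
            t.copy_(torch.randn(t.shape, device="cpu"))
```

The loop over the low-rank factors mentioned above (`for q in qs:`) would live inside the forked-RNG block, which is what makes factoring it out slightly tricky.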

@rohan-varma (Member) left a comment:

Thanks for answering all my questions/comments! LGTM; we can refactor the tests in a follow-up PR as you mentioned.

wayi1 pushed a commit that referenced this pull request Dec 18, 2020
Pull Request resolved: #49417

ghstack-source-id: 118921275

Differential Revision: [D25511543](https://our.internmc.facebook.com/intern/diff/D25511543/)
@facebook-github-bot (Contributor):
This pull request has been merged in 71f3399.

@mrshenli (Contributor):

This breaks the multi-GPU test on master. Reverting.

https://app.circleci.com/pipelines/github/pytorch/pytorch/253510/workflows/d97063d2-eeb6-470c-88c8-7a64553078fc/jobs/9759703

Dec 19 03:33:52   test_DistributedDataParallel_powerSGD_ddp_comm_hook (__main__.TestDistBackendWithFork) ... ERROR:root:Caught exception: 
Dec 19 03:33:52 Traceback (most recent call last):
Dec 19 03:33:52   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 283, in wrapper
Dec 19 03:33:52     fn()
Dec 19 03:33:52   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 97, in wrapper
Dec 19 03:33:52     return func(*args, **kwargs)
Dec 19 03:33:52   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_distributed.py", line 156, in wrapper
Dec 19 03:33:52     return func(*args, **kwargs)
Dec 19 03:33:52   File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/distributed/distributed_test.py", line 2851, in test_DistributedDataParallel_powerSGD_ddp_comm_hook
Dec 19 03:33:52     loss.backward()
Dec 19 03:33:52   File "/opt/conda/lib/python3.6/site-packages/torch/tensor.py", line 233, in backward
Dec 19 03:33:52     torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
Dec 19 03:33:52   File "/opt/conda/lib/python3.6/site-packages/torch/autograd/__init__.py", line 146, in backward
Dec 19 03:33:52     allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
Dec 19 03:33:52 RuntimeError: AttributeError: 'object' object has no attribute 'size'
Dec 19 03:33:52 
Dec 19 03:33:52 At:
Dec 19 03:33:52   /opt/conda/lib/python3.6/site-packages/torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py(118): powerSGD_hook
Dec 19 03:33:52 
Dec 19 03:33:52 exiting process with exit code: 10
Dec 19 03:33:53 Process 2 terminated with exit code 10, terminating remaining processes.
Dec 19 03:33:53 ERROR (2.849s)
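
For context, the failing test exercises this code path roughly as follows: the PowerSGD hook is registered as a DDP communication hook, and DDP then invokes it with a gradient bucket during loss.backward(). A minimal registration sketch, assuming the current torch.distributed API (constructor arguments may differ across versions, and process-group setup is omitted):

```python
import torch
import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes dist.init_process_group(backend="nccl", ...) has already been called
# and that this is a single-node job with one GPU per rank.
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = DDP(torch.nn.Linear(1000, 1000).cuda(rank), device_ids=[rank])

state = powerSGD.PowerSGDState(
    process_group=None,           # use the default process group
    matrix_approximation_rank=1,  # rank of the low-rank gradient approximation
)
model.register_comm_hook(state, powerSGD.powerSGD_hook)

# During loss.backward(), DDP hands each gradient bucket to the hook;
# the traceback above shows the hook failing while reading that bucket object.
```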

@facebook-github-bot (Contributor):
This pull request has been reverted by ad9923e.

facebook-github-bot pushed a commit that referenced this pull request Dec 20, 2020
…erSGD (#49639)

Summary:
Pull Request resolved: #49639

Resubmit #49417 with a fix for distributed_test.

The previous submission broke a multi-GPU test that runs on 4 GPUs. Since this test only runs on master, the failure could not be detected before submission.

The real diff is:
4ca1014

This time I have verified that the previously failing test `pytorch_linux_xenial_cuda10_2_cudnn7_py3_multigpu_test` passes after creating a PR (#49651) from a separate branch:
https://app.circleci.com/pipelines/github/pytorch/pytorch/253644/workflows/c1c02b70-0877-40e6-8b4c-61f60f6b70ed/jobs/9768079

ghstack-source-id: 118969912

Test Plan: buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_powerSGD_ddp_comm_hook

Reviewed By: mrshenli

Differential Revision: D25654961

fbshipit-source-id: 2a45c8ceb9bdb54ff7309a8b66ec87e913e0150e
wayi1 pushed a commit that referenced this pull request Dec 21, 2020
Address the comment on #49417 (comment)

Differential Revision: [D25673997](https://our.internmc.facebook.com/intern/diff/D25673997/)

ghstack-source-id: 119021459
Pull Request resolved: #49715
@facebook-github-bot deleted the gh/SciPioneer/35/head branch on December 22, 2020 at 15:17
wayi1 pushed a commit that referenced this pull request Dec 22, 2020
…use.size()"


Address the comment on #49417 (comment)

Differential Revision: [D25673997](https://our.internmc.facebook.com/intern/diff/D25673997/)

[ghstack-poisoned]
wayi1 pushed a commit that referenced this pull request Dec 22, 2020
Pull Request resolved: #49715

Address the comment on #49417 (comment)
ghstack-source-id: 119049598

Differential Revision: [D25673997](https://our.internmc.facebook.com/intern/diff/D25673997/)
facebook-github-bot pushed a commit that referenced this pull request Dec 23, 2020
…49715)

Summary:
Pull Request resolved: #49715

Address the comment on #49417 (comment)
ghstack-source-id: 119049598

Test Plan: waitforbuildbot

Reviewed By: rohan-varma

Differential Revision: D25673997

fbshipit-source-id: 44eb2540e5a77331c34ba503285cbd0bd63c2c0a
hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021
…orch#49417)

Summary:
Pull Request resolved: pytorch#49417

ghstack-source-id: 118921275

Test Plan:
buck test mode/dev-nosan caffe2/test/distributed:c10d -- test_powerSGD_ddp_comm_hook_nccl

buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork -- test_DistributedDataParallel_powerSGD_ddp_comm_hook

Reviewed By: rohan-varma

Differential Revision: D25511543

fbshipit-source-id: 19ef188bc2d4c7406443c8fa233c1f2c2f27d93c
hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021
…erSGD (pytorch#49639)

Labels: cla signed, Merged, oncall: distributed, Reverted

4 participants