
Migrate apex.parallel.SyncBatchNorm channels_last to pytorch #46906

Closed
wants to merge 26 commits

Conversation

@xwang233 (Collaborator) commented Oct 27, 2020

per title

This PR:

  • Migrate apex.parallel.SyncBatchNorm channels_last to pytorch torch.nn.SyncBatchNorm
  • Fix a TODO here by fusing sum, div kernels into backward elementwise kernel
```
# TODO: move div_ into batch_norm_backward_elemt kernel
num_channels = sum_dy.shape[0]
combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
torch.distributed.all_reduce(
    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
divisor = count_tensor.sum()
mean_dy = sum_dy / divisor
mean_dy_xmu = sum_dy_xmu / divisor
# backward pass for gradient calculation
grad_input = torch.batch_norm_backward_elemt(
    grad_output,
    saved_input,
    mean,
    invstd,
    weight,
    mean_dy,
    mean_dy_xmu
)
```
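To see why the div can be fused, here is a minimal pure-Python sketch with scalar stand-ins for the per-channel tensors (helper names are hypothetical; the real implementation is a CUDA kernel). It assumes the standard SyncBN backward elementwise formula, grad_in = (grad_out - mean_dy - (x - mean) * invstd^2 * mean_dy_xmu) * invstd * weight: dividing sum_dy and sum_dy_xmu by the count inside the elementwise step is algebraically identical to precomputing mean_dy and mean_dy_xmu, but skips the two standalone div kernel launches.

```python
# Pure-Python sketch (no torch); scalars stand in for per-channel tensors.
# Hypothetical helper names -- the real implementation is a CUDA kernel.

def backward_elemt_unfused(g, x, mean, invstd, weight, sum_dy, sum_dy_xmu, count):
    mean_dy = sum_dy / count          # standalone div kernel 1
    mean_dy_xmu = sum_dy_xmu / count  # standalone div kernel 2
    return (g - mean_dy - (x - mean) * invstd * invstd * mean_dy_xmu) * invstd * weight

def backward_elemt_fused(g, x, mean, invstd, weight, sum_dy, sum_dy_xmu, count):
    # divisions by count folded into the single elementwise pass
    return (g - sum_dy / count
            - (x - mean) * invstd * invstd * (sum_dy_xmu / count)) * invstd * weight
```

Since the two are algebraically equal, the fused form changes only the number of kernels launched, not the gradient values.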

Todo

Comment: This PR uses the apex version of the size check. Tests passed and I haven't seen anything wrong so far.

  • The restriction on using the channels_last kernel is:
```
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
}
```

I think we can relax that for channels_last_3d as well?
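For intuition, here is a rough stride-level sketch (pure Python, no torch; helper names are hypothetical) of what these memory formats mean: a 4-D NCHW tensor is channels_last when its data is laid out NHWC, and the channels_last_3d relaxation would accept 5-D NCDHW tensors laid out NDHWC.

```python
def channels_last_strides_2d(shape):
    # 4-D NCHW tensor stored as NHWC: channels are the fastest-moving dim
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)

def channels_last_strides_3d(shape):
    # 5-D NCDHW tensor stored as NDHWC (the channels_last_3d relaxation)
    n, c, d, h, w = shape
    return (d * h * w * c, 1, h * w * c, w * c, c)
```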

Comment: we don't have a benchmark for this now; will check this and add the functionality later when needed.

  • Add test
  • Add benchmark

Detailed benchmark is at https://github.com/xwang233/code-snippet/tree/master/syncbn-channels-last

Close #50781

@xwang233 (Collaborator, Author)

cc @ptrblck @jjsjann123 @ngimel

@dr-ci bot commented Oct 27, 2020

💊 CI failures summary and remediations

As of commit 4a993f4 (more details on the Dr. CI page):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test (1/1)

Step: "Run tests" (full log | diagnosis details | 🔁 rerun)

Mar 02 22:39:26 [E request_callback_no_python.cpp:656] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Mar 02 22:39:26 At:
Mar 02 22:39:26   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(122): serialize
Mar 02 22:39:26   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(175): serialize
Mar 02 22:39:26 
Mar 02 22:39:26 [E request_callback_no_python.cpp:656] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Mar 02 22:39:26 
Mar 02 22:39:26 At:
Mar 02 22:39:26   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(122): serialize
Mar 02 22:39:26   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(175): serialize
Mar 02 22:39:26 
Mar 02 22:39:26 [E request_callback_no_python.cpp:656] Received error while processing request type 258: RuntimeError: Can not pickle torch.futures.Future
Mar 02 22:39:26 
Mar 02 22:39:26 At:
Mar 02 22:39:26   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(122): serialize
Mar 02 22:39:26   /opt/conda/lib/python3.6/site-packages/torch/distributed/rpc/internal.py(175): serialize
Mar 02 22:39:26 
Mar 02 22:39:26 ok (1.530s)
Mar 02 22:39:27   test_return_future_remote (__main__.TensorPipeRpcTestWithSpawn) ... ok (1.529s)
Mar 02 22:39:29   test_return_local_rrefs (__main__.TensorPipeRpcTestWithSpawn) ... ok (1.530s)
Mar 02 22:39:35   test_rpc_profiling_async_function (__main__.TensorPipeRpcTestWithSpawn) ... ok (6.137s)
Mar 02 22:39:41   test_rpc_profiling_async_function_single_threaded (__main__.TensorPipeRpcTestWithSpawn) ... ok (5.937s)

1 job timed out:

  • pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_test

🚧 1 ongoing upstream failure:

These were probably caused by upstream breakages that are not fixed yet.



@ngimel (Collaborator) commented Oct 27, 2020

Cool, can you post benchmarks comparing to apex?

@pritamdamania87 (Contributor) commented Oct 29, 2020

@lly-zero-one Would it be possible to test this PR with some of our ClassyVision workflows to see the potential benefit?

@xwang233 (Collaborator, Author)

The detailed benchmark and raw data are at https://github.com/xwang233/code-snippet/tree/master/syncbn-channels-last.

For 2D and 4D tensors on V100 x8 (relative perf is similar on A100 x8), the charts below compare kernel execution time (not including NCCL reduction/gather, kernel launch overhead, or tensor memory-format transformation):

(chart: new channels_last vs. master contiguous)

(chart: new channels_last vs. apex channels_last)

@xwang233 xwang233 changed the title [WIP] Migrate apex.parallel.SyncBatchNorm channels_last to pytorch Migrate apex.parallel.SyncBatchNorm channels_last to pytorch Oct 30, 2020
@albanD albanD requested a review from ngimel October 30, 2020 13:37
@albanD albanD added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Oct 30, 2020
@facebook-github-bot (Contributor)

Hi @xwang233!

Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but we do not have a signature on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@lly-zero-one (Contributor)

@lly-zero-one Would it be possible to test this PR with some of our ClassyVision workflows to see the potential benefit?

@pritamdamania87 The ClassyVision flow is using Apex. Maybe we should ask the CV team to change their flow.

@xwang233 (Collaborator, Author) commented Nov 3, 2020

@ngimel The CLA is ready.

@VitalyFedyunin (Contributor)

Hi! Can you please rebase? Thanks!

@facebook-github-bot (Contributor)

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@ngimel (Collaborator) left a comment

lgtm, let's wait for the tests

@ngimel (Collaborator) commented Mar 1, 2021

bc compat error is real

@xwang233 (Collaborator, Author) commented Mar 1, 2021

bc compat error is real

Yes, it is intentional. We fused the mean calculation (sum / count) into the normalization-apply kernel, which reduces the number of kernel launches.

@ngimel (Collaborator) commented Mar 1, 2021

I understand, but then you should add it to exceptions in bc compat test
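For context, these exceptions live in an allow list in test/backward_compatibility/check_backward_compatibility.py. A hedged sketch of the idea (the exact entry format may differ between versions): entries pair an operator-name prefix with an expiry date, and schema changes matching an unexpired entry are tolerated.

```python
import datetime

# Hypothetical sketch of a bc-compat allow list: (operator-name prefix, expiry date).
ALLOW_LIST = [
    ("aten::batch_norm_backward_elemt", datetime.date(2021, 4, 1)),
]

def is_allowed(op_name, today):
    # A schema change is tolerated if the op matches an unexpired entry.
    return any(op_name.startswith(prefix) and today <= expiry
               for prefix, expiry in ALLOW_LIST)
```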

@ngimel (Collaborator) commented Mar 2, 2021

I don't want to delay this PR, but consider making functions like batch_norm_backward_elemt private by renaming them to _batch_norm_backward_elemt in a follow-up.

@facebook-github-bot (Contributor) left a comment

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@malfet merged this pull request in d30f4d1.

@osalpekar (Member)

Hey @xwang233 @ngimel , it looks like the pytorch_linux_backward_compatibility_check_test test has been failing on master since this PR was merged. Have the appropriate exceptions been added to the bc compat test in this PR?

@xwang233 (Collaborator, Author) commented Mar 4, 2021

Yes, the bc-compat exception was added in c6c680a.

The error message you saw was due to a revert of the previous commit in the master branch, before this one:

Mar 04 00:53:49 The PR is introducing backward incompatible changes to the operator library. Please contact PyTorch team to confirm whether this change is wanted or not. 
Mar 04 00:53:49 
Mar 04 00:53:49 Broken ops: [
Mar 04 00:53:49 	aten::_lstsq_helper(Tensor a, Tensor b, float cond, str? driver_name) -> (Tensor, Tensor, Tensor)
Mar 04 00:53:49 	aten::linalg_lstsq(Tensor self, Tensor b, float? cond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)
Mar 04 00:53:49 ]

aocsa pushed a commit to Quansight/pytorch that referenced this pull request Mar 15, 2021
…#46906)

Summary:
per title

This PR did
- Migrate `apex.parallel.SyncBatchNorm` channels_last to pytorch `torch.nn.SyncBatchNorm`
- Fix a TODO here by fusing `sum`, `div` kernels into backward elementwise kernel
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L76-L95

Todo
- [x] Discuss a regression introduced in pytorch#37133 (comment), which is the synchronized copy here
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L32-L34

**Comment**: This PR uses apex version for the size check. Test passed and I haven't seen anything wrong so far.

- [x] The restriction to use channels_last kernel will be like this
```
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
}
```
I think we can relax that for channels_last_3d as well?

**Comment**: we don't have benchmark for this now, will check this and add functionality later when needed.
- [x] Add test
- [x] Add benchmark

Detailed benchmark is at https://github.com/xwang233/code-snippet/tree/master/syncbn-channels-last

Close pytorch#50781

Pull Request resolved: pytorch#46906

Reviewed By: albanD

Differential Revision: D26771437

Pulled By: malfet

fbshipit-source-id: d00387044e9d43ac7e6c0e32a2db22c63d1504de
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021
pytorchmergebot pushed a commit that referenced this pull request Nov 15, 2022
This PR enabled the use of fast channels_last kernels on SyncBatchNorm with channels_last_3d memory format.

With a small benchmark script here #88021 (comment), on V100, I got

master:
```
DDP channels_last=False, run_forward_backward, time: 0.8945400714874268 sec
DDP channels_last=True, run_forward_backward, time: 1.4736433029174805 sec
```

This PR:
```
DDP channels_last=False, run_forward_backward, time: 0.8927242755889893 sec
DDP channels_last=True, run_forward_backward, time: 0.48697471618652344 sec
```

This PR is a follow-up of #46906

Close #88021
Pull Request resolved: #88401
Approved by: https://github.com/ngimel
kulinseth pushed a commit to kulinseth/pytorch that referenced this pull request Dec 10, 2022
Labels: cla signed, Merged, open source, triaged
10 participants