
Conversation

TomoshibiAkira
Contributor

Updates the requirements on input dimensions for torch.nn.SyncBatchNorm:

  1. Checks the aggregated batch size `count_all` instead of the batch size in every DDP process (SyncBatchNorm size check #36865)
  2. Adds a test function for SyncBatchNorm where every process has only 1 input
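
For context, a simplified sketch of the change (the variable names follow the snippet from torch/nn/modules/_functions.py quoted later in this thread; this is an illustration, not the exact diff):

```
# Before: each DDP process checked only its own local batch size, so a
# per-process batch of 1 raised even when the global batch was larger.
size = input.numel() // input.size(1)
if size == 1:
    raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

# After (sketch): per-process counts are gathered first, so the check runs
# on the batch size aggregated across the whole DDP group.
size = count_all.view(-1).long().sum()
if size == 1:
    raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
```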


dr-ci bot commented Apr 23, 2020

💊 Build failures summary and remediations

As of commit 8dfd2d0 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@mrshenli
Contributor

cc @zhaojuanmao

@zhangguanheng66 added the oncall: distributed, module: nn, and triaged labels on Apr 23, 2020
Contributor

@zhaojuanmao left a comment

lgtm. thanks for working on it. just one minor comment.

question: did the test fail without moving to check count_all?

```
@@ -2135,6 +2135,56 @@ def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
        )
        self._barrier()

    @unittest.skipIf(BACKEND != 'nccl' and BACKEND != 'gloo',
```
Contributor

`@unittest.skipIf(BACKEND != 'nccl' ....)` — SyncBatchNorm is only supported for NCCL

Contributor Author

Yeah, you're right. But since the original SyncBatchNorm test function test_DistributedDataParallel_SyncBatchNorm() uses this flag (and I'm not sure why), I might just continue the tradition. :)

About the failing CI test: it seems to be related to some networking issue (a weird timeout) and is probably not related to the code.
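
For illustration, a minimal standalone sketch of the scenario the new test covers (a hypothetical repro, not the PR's actual test; the rendezvous settings and feature count are assumptions):

```
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = nn.parallel.DistributedDataParallel(
        nn.SyncBatchNorm(4).cuda(rank), device_ids=[rank])
    # Batch size 1 on this process; the aggregated batch size across ranks
    # is world_size > 1, so the forward should no longer raise after the fix.
    x = torch.randn(1, 4, device='cuda:{}'.format(rank))
    model(x).sum().backward()
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)
```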

@zhaojuanmao
Contributor

looks good, would you please rebase? there are some unrelated test failures

@TomoshibiAkira force-pushed the syncbn_size_check_fix branch from 5e7a49f to 2f71a1d on April 29, 2020 18:41
@TomoshibiAkira force-pushed the syncbn_size_check_fix branch from 2f71a1d to 8dfd2d0 on April 29, 2020 18:51
Contributor Author

TomoshibiAkira commented Apr 30, 2020

@zhaojuanmao done. All tests pass now.

Contributor

@facebook-github-bot left a comment

@zhaojuanmao is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@zhaojuanmao merged this pull request in ae755a7.

@TomoshibiAkira deleted the syncbn_size_check_fix branch on May 3, 2020 16:05
```
size = count_all.view(-1).long().sum()
if size == 1:
    raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))
```

Collaborator

@xwang233 commented Oct 27, 2020

cc @ngimel @jjsjann123 @ptrblck

I found that this change to the size calculation introduces a huge regression in NCCL reduction. For a sync BN forward of a (64, 2048, 4, 4) float tensor on my machine with 2 GPUs, the previous code takes 0.7 ms, but the current code takes 2.4 ms.

I think this is because the old size calculation is purely on CPU, while the new size calculation is on GPU and needs a synchronization for the size == 1 comparison on CPU. This can easily be recovered by 1. using the old size calculation, or 2. removing the if statement.
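
To make the synchronization point concrete, a minimal sketch (assuming a CUDA build; the names mirror the snippet above):

```
import torch

# count_all lives on the GPU after the all_gather of per-process counts.
count_all = torch.ones(2, 1, device='cuda')

# New check (sketch): .sum() yields a GPU scalar, and evaluating the
# `size == 1` branch copies it to the CPU, synchronizing the stream.
size = count_all.view(-1).long().sum()
if size == 1:  # implicit GPU -> CPU sync happens here
    raise ValueError('Expected more than 1 value per channel when training')

# Old check (sketch): pure CPU integer arithmetic on the local shape,
# so no device synchronization is involved.
input = torch.randn(64, 2048, 4, 4, device='cuda')
size = input.numel() // input.size(1)
if size == 1:  # plain Python int comparison, no sync
    raise ValueError('Expected more than 1 value per channel when training')
```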

I'm currently working on migrating the sync BN channels-last implementation from apex, and we can discuss a fix there.

Also, I suggest that we require performance benchmarks for relevant PRs in the future.

xwang233 added a commit to xwang233/pytorch that referenced this pull request Oct 27, 2020
facebook-github-bot pushed a commit that referenced this pull request Mar 3, 2021
Summary:
per title

This PR:
- Migrates `apex.parallel.SyncBatchNorm` channels_last to pytorch `torch.nn.SyncBatchNorm`
- Fixes a TODO here by fusing the `sum` and `div` kernels into the backward elementwise kernel:
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L76-L95

Todo
- [x] Discuss a regression introduced in #37133 (comment), which is the synchronized copy here
https://github.com/pytorch/pytorch/blob/b167402e2e66a663cd9913885552929b4c045ffa/torch/nn/modules/_functions.py#L32-L34

**Comment**: This PR uses the apex version of the size check. Tests passed and I haven't seen anything wrong so far.

- [x] The restriction on using the channels_last kernel will be like this:
```
inline bool batch_norm_use_channels_last_kernels(const at::Tensor& self) {
  return self.is_contiguous(at::MemoryFormat::ChannelsLast) || self.ndimension() == 2;
}
```
I think we can relax that for channels_last_3d as well?

**Comment**: we don't have a benchmark for this now; we will check this and add the functionality later when needed.
- [x] Add test
- [x] Add benchmark

Detailed benchmark is at https://github.com/xwang233/code-snippet/tree/master/syncbn-channels-last

Close #50781

Pull Request resolved: #46906

Reviewed By: albanD

Differential Revision: D26771437

Pulled By: malfet

fbshipit-source-id: d00387044e9d43ac7e6c0e32a2db22c63d1504de
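
On the channels_last_3d question in the commit message above: a Python analogue of what such a relaxation could look like (an illustrative assumption, not code from this PR):

```
import torch

def use_channels_last_kernels(t: torch.Tensor) -> bool:
    # Current restriction: 4-D NHWC (channels_last) or plain 2-D input.
    # Hypothetical relaxation: also accept 5-D NDHWC (channels_last_3d).
    return (t.is_contiguous(memory_format=torch.channels_last)
            or t.is_contiguous(memory_format=torch.channels_last_3d)
            or t.dim() == 2)
```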
aocsa pushed a commit to Quansight/pytorch that referenced this pull request Mar 15, 2021
xsacha pushed a commit to xsacha/pytorch that referenced this pull request Mar 31, 2021