
Conversation

Contributor

mrshenli commented Mar 30, 2022


fixes #36530

Prior to this commit, SyncBatchNorm crashes with the following
error message.

File "..../torch/nn/modules/_functions.py", line 17, in forward
    mean, invstd = torch.batch_norm_stats(input, eps)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous

This PR adds a dedicated branch to handle empty inputs. When a process
receives an empty input, it sets its local `mean`, `invstd`, and `count`
to zero and still participates in the `all_gather` collective communication in
the forward pass. `mean` and `invstd` entries with a zero count are then
filtered out before computing the global mean and invstd. In the backward
pass, the process also participates in the `all_reduce` communication with zero
tensors to unblock its peers.
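
A rough sketch of the idea (illustrative only: the helper name `gather_norm_stats`, its signature, and the `world_size`/`process_group` arguments are assumptions, not the PR's exact code):

```
import torch
import torch.distributed as dist

# Illustrative sketch of the empty-input handling, not the PR's exact code.
def gather_norm_stats(input, eps, num_channels, process_group, world_size):
    if input.numel() > 0:
        mean, invstd = torch.batch_norm_stats(input, eps)
        count = torch.full((1,), input.numel() // num_channels,
                           dtype=input.dtype, device=input.device)
    else:
        # Empty local batch: contribute all-zero stats so the collective
        # still runs and peer processes are not blocked.
        mean = torch.zeros(num_channels, dtype=input.dtype, device=input.device)
        invstd = torch.zeros(num_channels, dtype=input.dtype, device=input.device)
        count = torch.zeros(1, dtype=input.dtype, device=input.device)

    combined = torch.cat([mean, invstd, count], dim=0)
    gathered = [torch.empty_like(combined) for _ in range(world_size)]
    dist.all_gather(gathered, combined, group=process_group)

    stats = torch.stack(gathered)                      # (world_size, 2C + 1)
    counts = stats[:, -1]
    nonzero = counts > 0                               # drop ranks that saw empty input
    mean_all = stats[nonzero, :num_channels]
    invstd_all = stats[nonzero, num_channels:2 * num_channels]
    return mean_all, invstd_all, counts[nonzero]
```

The actual change lives in `SyncBatchNorm` in `torch/nn/modules/_functions.py`; the backward pass follows the same pattern, all-reducing zero tensors when the local input is empty.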

Differential Revision: [D35273409](https://our.internmc.facebook.com/intern/diff/D35273409)

TODO:
1. Avoid copying `count_all` to CPU if possible.
2. The crash is gone, but the output is NaN.

Next step will be to move the fix into the CUDA kernel of
`batch_norm_gather_stats_with_counts`.

Contributor

facebook-github-bot commented Mar 30, 2022

💊 CI failures summary and remediations

As of commit d5f20a8 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-rocm5.0-py3.7 / test (default, 2, 2, linux.rocm.gpu) (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun)

```
2022-04-01T16:04:53.0264233Z   test_where_scalar_valid_combination_cuda_int32 (__main__.TestTorchDeviceTypeCUDA) ... ok (0.005s)
2022-04-01T16:04:53.0388759Z   test_where_scalar_valid_combination_cuda_int64 (__main__.TestTorchDeviceTypeCUDA) ... ok (0.012s)
2022-04-01T16:04:53.0436659Z   test_where_scalar_valid_combination_cuda_int8 (__main__.TestTorchDeviceTypeCUDA) ... ok (0.005s)
2022-04-01T16:04:53.0488375Z   test_where_scalar_valid_combination_cuda_uint8 (__main__.TestTorchDeviceTypeCUDA) ... ok (0.005s)
2022-04-01T16:04:53.0494956Z   test_cuda_vitals_gpu_only_cuda (__main__.TestVitalSignsCudaCUDA) ... [TORCH_VITAL] Dataloader.enabled		 True
2022-04-01T16:04:53.0499586Z [TORCH_VITAL] Dataloader.basic_unit_test		 TEST_VALUE_STRING
2022-04-01T16:04:53.0500440Z [TORCH_VITAL] CUDA.used		 true
2022-04-01T16:04:53.0501585Z ok (0.001s)
2022-04-01T16:04:53.0501944Z 
2022-04-01T16:04:53.0502257Z ======================================================================
2022-04-01T16:04:53.0503226Z FAIL [0.023s]: test_invalid_shapes_grid_sampler_cuda (__main__.TestTorchDeviceTypeCUDA)
2022-04-01T16:04:53.0504771Z ----------------------------------------------------------------------
2022-04-01T16:04:53.0505843Z RuntimeError: cudnn_grid_sampler_forward: ATen not compiled with cuDNN support
2022-04-01T16:04:53.0506458Z 
2022-04-01T16:04:53.0506937Z During handling of the above exception, another exception occurred:
2022-04-01T16:04:53.0510638Z 
2022-04-01T16:04:53.0511284Z Traceback (most recent call last):
2022-04-01T16:04:53.0513131Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_utils.py", line 1780, in wrapper
2022-04-01T16:04:53.0514459Z     method(*args, **kwargs)
2022-04-01T16:04:53.0515972Z   File "/opt/conda/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 376, in instantiated_test
2022-04-01T16:04:53.0517465Z     result = test(self, **param_kwargs)
```

This comment was automatically generated by Dr. CI.

mrshenli added a commit that referenced this pull request Mar 30, 2022

ghstack-source-id: ff6e423
Pull Request resolved: #74944
facebook-github-bot added the oncall: distributed label Mar 30, 2022
Contributor

datumbox left a comment

Thanks for the investigation @mrshenli. I've had a look as well and added a few comments. Let me know your thoughts.

mrshenli added a commit that referenced this pull request Mar 31, 2022
ghstack-source-id: a82e2dc
Pull Request resolved: #74944
Contributor Author

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

datumbox left a comment

Thanks for the change @mrshenli.

Overall the approach looks good to me. I've added a few minor nit comments. I'm currently testing this patch on a cluster using real data, and it seems that the problem is resolved. If something breaks, I'll let you know.

```
    combined = torch.cat([mean, invstd, count], dim=0)
else:
    # for empty input, directly set all stats to 0
    combined = torch.zeros(
```
Collaborator

Wouldn't something like `torch.zeros(dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)` also work and reduce the bandwidth that is wasted?
Not sure how the RPC layer handles non-contiguous Tensors.

Contributor Author

torch.zeros(dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)

Curious, what bandwidth would the above code save? And why is RPC relevant here?

Collaborator

This "combined" Tensor is shared with all other nodes during the all reduce below right?
And while the Tensor in the code today has 2 * num_channels + 1 elements (that need to go through the wire), the expanded version has 1 element. So if it is sent over the wire effectively, you save a lot of bandwidth.
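
(As a rough, illustrative back-of-the-envelope; the channel count and the fp32 assumption below are placeholders, not numbers from the PR:)

```
# Rough per-rank size comparison, assuming fp32 stats and, say, 256 channels
# (both numbers are illustrative, not from the PR).
num_channels = 256
bytes_per_elem = 4
full_buffer = (2 * num_channels + 1) * bytes_per_elem   # 2052 bytes stored today
expanded = 1 * bytes_per_elem                           # 4 bytes of real storage after expand()
print(full_buffer, expanded)
```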

Contributor Author

Oh I see. Not sure if this is gonna work. Collectives use ProcessGroup and will call NCCL APIs under the hood. IIRC, NCCL expects contiguous tensors and will directly read numel() elements from the memory pointer. Let me double check on that.

Contributor Author

  File "/raid/shenli/pytorch/torch/distributed/distributed_c10d.py", line 2130, in _all_gather_base                                                             
    work = group._allgather_base(output_tensor, input_tensor)                                                                                                   
RuntimeError: Tensors must be contiguous                                                                                                                        
Exception raised from check_gpu_single_tensor at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1227 (most recent call first):  

Hit the above error, caused by the following check:

```
if (!tensor.is_contiguous()) {
  TORCH_CHECK(false, "Tensors must be contiguous");
}
```
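
For reference, the contiguity constraint can be reproduced without any process group; a minimal check, with `num_channels` as a placeholder:

```
import torch

# An expanded tensor has stride 0, so it fails the is_contiguous() check
# that ProcessGroupNCCL enforces above.
num_channels = 3  # placeholder value
combined = torch.zeros(1).expand(2 * num_channels + 1)
print(combined.is_contiguous())               # False
print(combined.stride())                      # (0,) -- all elements alias one value

# Calling .contiguous() would satisfy NCCL, but it materializes the full
# buffer again, so the storage saving from expand() is lost.
print(combined.contiguous().is_contiguous())  # True
```

So to go through NCCL, the expanded tensor would need a `.contiguous()` call, which recreates the full buffer and gives up the saving.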

Collaborator

Ok then.
As a side note, I think you should look into that as it is potentially a major bandwidth gain (and if I understand correctly, this is an expensive commodity).

```
num_channels = saved_input.shape[1]
if self.needs_input_grad[0]:
    # launch all_reduce to unblock other peer processes
    combined = torch.zeros(
```
Collaborator

Same question here about using an expanded Tensor to reduce bandwidth use.

mrshenli changed the title from "[WIP] Fix SyncBatchNorm for empty inputs" to "Fix SyncBatchNorm for empty inputs" Apr 1, 2022

pytorch-bot bot commented Apr 1, 2022

ci/master label does not do anything. Did you mean ciflow/trunk?


```
# input does not require grad
x.requires_grad = False
self._test_not_nan(model, x)
```
Contributor Author

@datumbox @albanD is there a way to test the grad value as well?

If I feed batch sizes 0 and 2 to the two processes, will it generate the same gradients as when they receive batch sizes 1 and 1 with the same data? I assume not, because invstd is no longer guaranteed to be the same in this case?

Contributor

I think I agree. It's not going to be the same gradient because the minibatch statistics will be different in the two cases.

Collaborator

@datumbox I don't get why you say "the minibatch statistics will be different" - in SyncBatchNorm the minibatch is all samples on all workers, so in @mrshenli's example it's the same minibatch (2 samples) in both cases.
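
A small standalone illustration of that point (toy numbers, not the PR's test; the count-weighted combination below only conceptually mirrors what `batch_norm_gather_stats_with_counts` does):

```
import torch

# Toy illustration (not the PR's test): combine per-worker stats the way
# SyncBatchNorm conceptually does, for 2 samples split as (0, 2) vs (1, 1).
eps = 1e-5
x = torch.tensor([[1.0], [3.0]])  # 2 samples, 1 channel

def local_stats(chunk):
    # Per-worker mean, biased variance, and count; empty chunks contribute zeros.
    if chunk.numel() == 0:
        return torch.zeros(1), torch.zeros(1), torch.zeros(1)
    return chunk.mean(0), chunk.var(0, unbiased=False), torch.tensor([float(len(chunk))])

def global_stats(chunks):
    means, variances, counts = map(torch.stack, zip(*(local_stats(c) for c in chunks)))
    total = counts.sum()
    g_mean = (means * counts).sum(0) / total
    # Combine variances via E[x^2] - E[x]^2, weighting each worker by its count;
    # zero-count workers drop out of the sums automatically.
    g_var = ((variances + means ** 2) * counts).sum(0) / total - g_mean ** 2
    return g_mean, (g_var + eps).rsqrt()

print(global_stats([x[:0], x]))       # split (0, 2)
print(global_stats([x[:1], x[1:]]))   # split (1, 1) -> same mean and invstd
```

With splits (0, 2) and (1, 1) over the same two samples, the count-weighted global mean and invstd come out identical, so the resulting gradients should match as well.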

mrshenli added a commit that referenced this pull request Apr 1, 2022
ghstack-source-id: b060e51
Pull Request resolved: #74944
Contributor Author

mrshenli commented Apr 1, 2022

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

mrshenli added a commit that referenced this pull request Apr 1, 2022
ghstack-source-id: d59971b
Pull Request resolved: #74944
Contributor Author

mrshenli commented Apr 1, 2022

@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Contributor

datumbox left a comment

LGTM from my side. My tests on real data show that the issue is fixed.

Collaborator

albanD left a comment

SGTM

facebook-github-bot pushed a commit that referenced this pull request Apr 1, 2022
Summary:
Pull Request resolved: #74944

fixes #36530

Differential Revision: D35273409

Test Plan: Imported from OSS

Reviewed By: datumbox

Pulled By: mrshenli

fbshipit-source-id: 1cee51eea866773c329b3fbf5da2be8a5fee6f0f
Contributor

github-actions bot commented Apr 1, 2022

Hey @mrshenli.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.


Labels

cla signed, oncall: distributed, release notes: distributed (ddp), topic: bug fixes
