Fix SyncBatchNorm for empty inputs #74944
Conversation
TODO: 1. Avoid copying `count_all` to the CPU if possible. 2. It no longer crashes, but the output is NaN. The next step will be to move the fix into the CUDA kernel of `batch_norm_gather_stats_with_counts` accordingly.
💊 CI failures summary and remediations. As of commit d5f20a8 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns. The following CI failures do not appear to be due to upstream breakages.
datumbox left a comment
Thanks for the investigation @mrshenli. I've had a look as well and added a few comments. Let me know your thoughts.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
datumbox left a comment
Thanks for the change @mrshenli.
Overall the approach looks good to me. I've added minor comments for nits. I'm currently testing this patch on a cluster using real data and it seems that the problem is resolved. If something breaks, I'll let you know.
```python
    combined = torch.cat([mean, invstd, count], dim=0)
else:
    # for empty input, directly set all stats to 0
    combined = torch.zeros(
```
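For readers following along, here is a hedged sketch of how the gathered statistics can be handled once every rank has contributed a row. This illustrates the approach described in the PR rather than quoting it verbatim; the `2 * num_channels + 1` layout (mean, invstd, count) and the dtype/device arguments follow the review discussion below.

```python
import torch

num_channels = 3                              # illustrative value
input = torch.empty(0, num_channels, 8, 8)    # an empty batch on this rank

# Empty input: local mean, invstd and count are all set to zero so this rank
# can still take part in the all_gather without crashing.
combined = torch.zeros(2 * num_channels + 1, dtype=input.dtype, device=input.device)

# Fake the gathered result for two ranks: this one (all zeros) and a peer
# that actually saw data (positive count in the last column).
combined_all = torch.stack([combined, torch.rand(2 * num_channels + 1) + 1.0])

# Split the rows back into per-process mean / invstd / count, then drop the
# rows whose count is zero before computing the global statistics.
mean_all, invstd_all, count_all = torch.split(combined_all, num_channels, dim=1)
mask = count_all.squeeze(-1) > 0
mean_all, invstd_all, count_all = mean_all[mask], invstd_all[mask], count_all[mask]
```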
Wouldn't something like `torch.zeros(dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)` also work and reduce the wasted bandwidth?
I'm not sure how RPC handles non-contiguous tensors.
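For illustration, a minimal sketch of what the suggested expanded tensor looks like next to the materialized one; the size-1 argument to `torch.zeros` and the concrete `num_channels` value are assumptions added here just to make the snippet runnable.

```python
import torch

num_channels = 3                          # illustrative value
input = torch.empty(0, num_channels)      # stand-in for the empty input

# Materialized buffer as in the PR: 2 * C + 1 real elements in memory.
combined = torch.zeros(2 * num_channels + 1, dtype=input.dtype, device=input.device)

# Suggested alternative: a single stored zero viewed as 2 * C + 1 elements.
combined_view = torch.zeros(1, dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)

print(combined.numel(), combined.stride())            # 7 (1,) -- seven stored values
print(combined_view.numel(), combined_view.stride())  # 7 (0,) -- one stored value, stride-0 view
```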
`torch.zeros(dtype=input.dtype, device=input.device).expand(2 * num_channels + 1)`
Curious, what bandwidth does the above code save? And why is RPC relevant here?
This "combined" Tensor is shared with all other nodes during the all reduce below right?
And while the Tensor in the code today has 2 * num_channels + 1 elements (that need to go through the wire), the expanded version has 1 element. So if it is sent over the wire effectively, you save a lot of bandwidth.
Oh I see. I'm not sure that is going to work, though. Collectives use ProcessGroup and call NCCL APIs under the hood. IIRC, NCCL expects contiguous tensors and will read numel() elements directly from the memory pointer. Let me double-check that.
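As a quick illustration of why that matters (my own check, not from the PR): the expanded view has stride 0, so it is not contiguous, and calling `.contiguous()` on it simply re-materializes the full buffer, giving up the savings.

```python
import torch

t = torch.zeros(1).expand(7)
print(t.is_contiguous())                # False: stride-0 view over a single element
print(t.contiguous().is_contiguous())   # True, but now 7 elements are actually allocated
```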
File "/raid/shenli/pytorch/torch/distributed/distributed_c10d.py", line 2130, in _all_gather_base
work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: Tensors must be contiguous
Exception raised from check_gpu_single_tensor at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1227 (most recent call first):
Hit the above error, caused by the following line.
pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp, lines 1226 to 1228 in 835cc66:

```cpp
if (!tensor.is_contiguous()) {
  TORCH_CHECK(false, "Tensors must be contiguous");
}
```
Ok then.
As a side note, I think you should look into that, as it is potentially a major bandwidth gain (and, if I understand correctly, bandwidth is an expensive commodity).
```python
num_channels = saved_input.shape[1]
if self.needs_input_grad[0]:
    # launch all_reduce to unblock other peer processes
    combined = torch.zeros(
```
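A hedged sketch of the backward-pass counterpart discussed here; the `2 * num_channels` size and the exact `all_reduce` arguments are my reading of the approach, not a verbatim copy of the PR. The point is simply that a rank with an empty input contributes all-zero reduction terms so the collective still completes on every peer.

```python
import torch
import torch.distributed as dist

def empty_input_backward_stats(saved_input, process_group):
    # One zero slot per channel for each of the two reduced quantities
    # (the gradient sums SyncBatchNorm's backward normally exchanges).
    num_channels = saved_input.shape[1]
    combined = torch.zeros(
        2 * num_channels, dtype=saved_input.dtype, device=saved_input.device
    )
    # Participating with zeros keeps the peers from blocking while adding
    # nothing to the reduced sums.
    dist.all_reduce(combined, op=dist.ReduceOp.SUM, group=process_group)
    return torch.split(combined, num_channels, dim=0)
```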
Same question here about using an expanded Tensor to reduce bandwidth use.
Fixes #36530.

Prior to this commit, SyncBatchNorm crashes with the following error message:

```
File "..../torch/nn/modules/_functions.py", line 17, in forward
    mean, invstd = torch.batch_norm_stats(input, eps)
RuntimeError: cannot reshape tensor of 0 elements into shape [0, 3, -1] because the unspecified dimension size -1 can be any value and is ambiguous
```

This PR adds a dedicated branch to handle empty inputs. When a process receives an empty input, it sets its local `mean`, `invstd`, and `count` to zero and participates in the `all_gather` collective communication in the forward pass. The `mean` and `invstd` entries with a zero count are then filtered out before computing the global mean and invstd. In the backward pass, the process also participates in the `all_reduce` communication with zero tensors to unblock its peers.

Differential Revision: [D35273409](https://our.internmc.facebook.com/intern/diff/D35273409)
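For reference, a minimal sketch that reproduces the per-rank failure mode quoted above. It needs a CUDA device, since `torch.batch_norm_stats` is a CUDA kernel, and it exercises only the single call rather than a full DDP setup; shapes are illustrative.

```python
import torch

x = torch.empty(0, 3, 8, 8, device="cuda")  # an empty batch with 3 channels
try:
    # Before the fix, SyncBatchNorm's forward called this unconditionally
    # and crashed here on ranks that received no samples.
    mean, invstd = torch.batch_norm_stats(x, 1e-5)
except RuntimeError as err:
    print(err)  # "cannot reshape tensor of 0 elements into shape [0, 3, -1] ..."
```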
```python
# input does not require grad
x.requires_grad = False
self._test_not_nan(model, x)
```
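For context, a standalone sketch of the kind of check the `_test_not_nan` helper in the excerpt presumably performs; the helper's real body is not shown here, so the function below is an assumption for illustration only.

```python
import torch

def assert_not_nan(model, x):
    # Run a forward pass and make sure no NaNs appear in the output.
    out = model(x)
    assert not torch.isnan(out).any().item()
    # When the input carries gradients, also make sure backward stays NaN-free.
    if x.requires_grad:
        out.sum().backward()
        assert not torch.isnan(x.grad).any().item()
```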
I think I agree. It's not going to be the same gradient because the minibatch statistics will be different in the two cases.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@mrshenli has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
datumbox left a comment
LGTM from my side. My tests on real data show that the issue is fixed.
albanD left a comment
SGTM
Summary: Pull Request resolved: #74944. Fixes #36530.

Differential Revision: D35273409

Test Plan: Imported from OSS

Reviewed By: datumbox

Pulled By: mrshenli

fbshipit-source-id: 1cee51eea866773c329b3fbf5da2be8a5fee6f0f
Hey @mrshenli.
Stack from ghstack:

Fixes #36530.

Prior to this commit, SyncBatchNorm crashes with the error message quoted above.

This PR adds a dedicated branch to handle empty inputs. When a process receives an empty input, it sets its local `mean`, `invstd`, and `count` to zero and participates in the `all_gather` collective communication in the forward pass. The `mean` and `invstd` entries with a zero count are then filtered out before computing the global mean and invstd. In the backward pass, the process also participates in the `all_reduce` communication with zero tensors to unblock its peers.

Differential Revision: D35273409