[sync BN] #14267
Conversation
Summary: Added synchronized batch normalization, which allows synchronization of stats across mini-batches between processes within a process group. The current implementation uses a mixture of extended ATen native functions (C++/CUDA extension) and torch.nn.modules (c10d Python API). This is a WIP: 1. only supports GPU; 2. only supports a single GPU per process.
Master NCCL is broken. This PR requires #14244 to function.
This is the first phase of this PR; we want to have sync BN support in place first. My monkey Python tests run multiple processes and communications, and I used them for functional testing. As I cannot find an official upstream module doing similar things, feedback or hints would be greatly appreciated.
Pinging @ssnl for visibility.
Can I ask, at a very high level, what strategy is being used here to implement sync BN?
@weiyangfb If I understand the question correctly, the implementation here is to: 1. calculate stats (mean/var) for the local mini-batch; 2. all-reduce the stats across all processes to calculate the global mean/var; 3. apply element-wise normalization. The backward pass follows identical logic with slightly different arithmetic.
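At a very high level, a minimal sketch of those three steps (illustrative only, not the PR's ATen/CUDA kernels, which gather Welford statistics as discussed further down; the function name and the equal-local-batch-size assumption are mine):

import torch
import torch.distributed as dist

def sync_bn_forward_sketch(x, weight, bias, process_group, eps=1e-5):
    # 1. local per-channel statistics for this process's mini-batch (NCHW input)
    mean = x.mean(dim=[0, 2, 3])
    meansq = (x * x).mean(dim=[0, 2, 3])

    # 2. average the stats across the process group
    #    (exact only if every process has the same local batch size)
    world_size = dist.get_world_size(process_group)
    for stat in (mean, meansq):
        dist.all_reduce(stat, op=dist.ReduceOp.SUM, group=process_group)
        stat.div_(world_size)
    var = meansq - mean * mean

    # 3. element-wise normalization with the global mean/var
    x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)
    return x_hat * weight[None, :, None, None] + bias[None, :, None, None]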
fixing backwards path with process group; fixing linter issue
Many failed tests. Seems like I got a lemon commit in master. Will merge again later.
Thank you! Structure looks good to me. Kernels seem good too, although I didn't try to understand the details (e.g., math). I have some general questions, and this definitely needs some tests.
mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
invstd_l = [invstd_all.narrow(0, i, 1) for i in range(world_size)]
# using all_gather instead of all reduce so we can calculate mean/var in one go
torch.distributed.all_gather(mean_l, mean, process_group)
Would it be more efficient to coalesce the gathered tensors first and make only one all_gather call? Or is it because they may be of different precisions?
Fewer NCCL calls, so technically yes. Realistically speaking, the actual communication is tiny in the overall timeline. The synchronization might have some impact here, but since we have two consecutive all_gather/all_reduce calls, combining the two doesn't really help.
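For reference, coalescing would look roughly like the following (an illustrative sketch, not this PR's code; the helper name is hypothetical): pack the per-channel mean and invstd into one buffer so a single all_gather moves both.

import torch
import torch.distributed as dist

def gather_stats_coalesced(mean, invstd, process_group, world_size):
    packed = torch.stack([mean, invstd])                    # (2, C)
    out = [torch.empty_like(packed) for _ in range(world_size)]
    dist.all_gather(out, packed, group=process_group)       # one collective
    mean_all = torch.stack([t[0] for t in out])             # (world_size, C)
    invstd_all = torch.stack([t[1] for t in out])           # (world_size, C)
    return mean_all, invstd_all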
Yep. Looks like we are doing similar things here with batch_norm_update_stats (exposing batch_norm_update_stats in at::native). I need this for synchronization; IIUC, @apaszke did this to fuse the second-step batchnorm point-wise kernels with following point-wise ops like ReLU, etc. We'll need to resolve the conflicts before merging.
The global mean & variance can be computed with just one all_reduce, by all-reducing the per-process sums of x and x² and deriving the mean and variance from them.
This is the strategy other existing implementations (in mxnet, tensorflow, caffe2, MegDet) are using, and using AllReduce is supposed to be more communication-efficient than using AllGather.
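A sketch of that single-all_reduce variant (illustrative, not any of those implementations verbatim; the function name is hypothetical): packing sum(x), sum(x²) and the element count into one buffer keeps it to a single collective and stays exact even with unequal per-process batch sizes.

import torch
import torch.distributed as dist

def global_mean_var_one_allreduce(x, process_group):
    C = x.size(1)
    count = torch.tensor([x.numel() / C], device=x.device, dtype=x.dtype)
    # pack sum(x), sum(x^2) and the local element count into one buffer
    buf = torch.cat([x.sum(dim=[0, 2, 3]), (x * x).sum(dim=[0, 2, 3]), count])
    dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=process_group)
    total_sum, total_sqsum, total_count = torch.split(buf, [C, C, 1])
    mean = total_sum / total_count
    var = total_sqsum / total_count - mean * mean    # E[x^2] - E[x]^2
    return mean, var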
I use Welford to calculate mean/var in a single pass as well. Welford has better numerical characteristics, which is desired. The all_gather called here only gathers one set of intermediate mean/m2n values per process; for a reasonable cluster size there shouldn't be much difference between all_gather and all_reduce.
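For context, merging the gathered per-process Welford accumulators uses the standard parallel (Chan et al.) combine step, roughly as follows (a sketch under my own naming, not the PR's CUDA kernel):

# Merge two Welford accumulators: mean, m2 (sum of squared deviations), n (count).
# Folding the gathered per-process triples together pairwise like this gives the
# same global mean/var as a single pass, with better numerical behaviour than the
# naive sum(x)/sum(x^2) formulation.
def welford_merge(mean_a, m2_a, n_a, mean_b, m2_b, n_b):
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return mean, m2, n

# the global (biased) variance is then m2 / n, and invstd = 1 / sqrt(var + eps)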
1. fallback to batch_norm when sync is not required; 2. in-place operators to save memory; 3. switch from narrow to unbind.
Timed out in the c10d test. Saw a similar test failure on another PR: #15540. Anything I should be concerned about?
@apaszke I'm repeating Carilli's question regarding the last review comments.
1. added a check so that SyncBatchNorm is only supported for single-GPU-per-process runs with DistributedDataParallel; 2. added a utility function to convert BatchNorm layers in a module to SyncBatchNorm layers
@apaszke Added a check during DDP initialization so that SyncBatchNorm is only supported with a single GPU per process under a DDP launch.
cc: @mrshenli @teng-li and @pietern for review.
Launching with
There is a potential for deadlock with the current calls to allreduce, as they execute in parallel with autograd, and therefore with the primary DDP reduction hooks.
mean_dy.div_(world_size)
torch.distributed.all_reduce(
    mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group)
mean_dy_xmu.div_(world_size)
These allreduce calls can interfere with the ones kicked off by DDP itself.
If autograd runs single-threaded with deterministic ordering you'll be fine, but as soon as it doesn't (e.g. there are multiple branches of the forward graph where the backward functions can be called in parallel), you'll run into deadlocks. This can be avoided by creating a new process group from the main one with new_group and using that throughout. Note that having multiple of these sync batch norm layers running backward in parallel can still deadlock, or worse, result in mixed-up data, so for guaranteed correctness you'll have to use a separate process group per sync batch norm layer. This is not ideal and we may need to find a different solution for this.
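Concretely, that workaround could look roughly like this (a sketch under assumed names; num_sync_bn_layers and the construction pattern are not this PR's code):

import torch.distributed as dist

# (assumes torch.distributed is already initialized)
# new_group() is itself a collective over the default group, so every rank must
# execute these calls, in the same order, before the model is built.
num_sync_bn_layers = 4   # assumed: how many SyncBatchNorm layers the model has
world = list(range(dist.get_world_size()))
sync_bn_groups = [dist.new_group(ranks=world) for _ in range(num_sync_bn_layers)]

# each sync batch norm layer then gets its own dedicated group, e.g.
# layer_i = SyncBatchNorm(num_features, process_group=sync_bn_groups[i]),
# so its collectives can never interleave with DDP's gradient allreduces.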
This is a gotcha for me. Thanks a lot for pointing out the issue. I don't fully understand how the allreduce call would cause trouble with branching while DDP handles that fine. Maybe I'll ask for more details about this in a private channel.
Just to reiterate: for the time being, it's a safe workaround as long as I have a separate process group per sync batch norm layer and use it for both the forward and backward pass.
I'll copy/create a process group inside the initializer of SyncBatchNorm.
Realized that duplicating process groups in the initializer of SyncBatchNorm would not work: new_group must be called by all processes in the main group, but inside the initializer or converter function each process only sees the process_group it was given for that layer.
I have this exact problem now!
My model has many SyncBNs in parallel (I also have an atypical module where I call autograd.grad() repeatedly on two tensors whose graph has SyncBNs). In short, the backward() has many all_reduce calls and some run in parallel. With DDP the code hangs without error, and with NCCL logs I can see that it is because all_reduce hangs.
I have tried torch native SyncBN, apex SyncBN, and @ppwwyyxx's NaiveSyncBN. This happens in both torch DDP and apex DDP (even with delay_allreduce=True).
Is there something I can do to fix it? @jjsjann123 @pietern
1. adding async_op for all_reduce calls 2. renaming variables 3. removing redundant code
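Presumably the async_op change lets the two backward allreduces overlap, along these lines (a sketch based on the snippet reviewed above, not the exact diff; the function name is mine):

import torch.distributed as dist

def reduce_backward_stats(mean_dy, mean_dy_xmu, process_group, world_size):
    # launch both collectives without blocking so they can overlap ...
    work1 = dist.all_reduce(mean_dy, dist.ReduceOp.SUM, process_group, async_op=True)
    work2 = dist.all_reduce(mean_dy_xmu, dist.ReduceOp.SUM, process_group, async_op=True)
    # ... and wait for both before the results are consumed
    work1.wait()
    work2.wait()
    mean_dy.div_(world_size)
    mean_dy_xmu.div_(world_size)
    return mean_dy, mean_dy_xmu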
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
We've decided we're going to go ahead and land this, and keep an eye on it for any problems that may occur later.
@pytorchbot rebase this please
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Failures look scary, should I be concerned?
Ah, the PR bitrotted. Just a moment please.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary:
- Added synchronized batch normalization; allows synchronization of stats across mini-batches between processes within a process group. The current implementation uses a mixture of extended ATen native functions (C++/CUDA extension) and torch.nn.modules (c10d Python API).
- User-facing API:
  1. torch.nn.utils.convert_sync_batchnorm(modules, process_group=None)
  2. torch.nn.SyncBatchNorm(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None)
- Supported use case: DistributedDataParallel with single-GPU, multi-process.
  a. The user creates a model containing `torch.nn.SyncBatchNorm` layers through one of the ways listed below:
     1. Use the layers directly: torch.nn.SyncBatchNorm(...), with a similar API to torch.nn.BatchNormXd(...) plus the added argument `process_group`, which limits the scope of synchronization to each process group. The default value is None, which implies synchronization across all GPUs.
     2. Use torch.nn.utils.convert_sync_batchnorm(modules, process_group) to recursively convert all `torch.nn.BatchNormXd` layers into `torch.nn.SyncBatchNorm`, preserving the values of parameters/buffers. The utility function also allows the user to set a process_group value for all converted layers.
  b. The user wraps their model with `torch.nn.parallel.DistributedDataParallel`; from this point on, the user should follow the general DDP usage guidelines.
- Error checking. For unsupported use cases, we error out:
  1. Application launched without DDP:
     > import torch
     > sbn = torch.nn.SyncBatchNorm(10).cuda()
     > inp = torch.randn(5, 10, 3, 3).cuda()
     > sbn(inp) --> Error!
     > AttributeError: SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel
  2. Application launched using DDP with multiple GPUs per process:
     > ddp_module = nn.parallel.DistributedDataParallel(module, device_ids=device_ids, output_device=args.local_rank)
     > ValueError: SyncBatchNorm is only supported for DDP with single GPU per process

Pull Request resolved: pytorch/pytorch#14267
Differential Revision: D14270035
Pulled By: ezyang
fbshipit-source-id: 4956d8fa565c32e9df5408d53719ff9f945f4d6d
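For illustration, a minimal end-to-end sketch of the supported use case above (single GPU per process). The launch assumptions, the toy model, and the environment-variable-based init are mine, and the converter is referenced under the name given in the summary:

import torch
import torch.distributed as dist
import torch.nn as nn

def main():
    # one process per GPU, e.g. launched via `python -m torch.distributed.launch`,
    # which sets the env vars that init_process_group reads
    dist.init_process_group(backend="nccl")
    local_rank = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # ordinary model with BatchNorm layers ...
    model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    # ... converted so the BN stats are synchronized across the (default) process group
    model = torch.nn.utils.convert_sync_batchnorm(model)
    model = model.cuda(local_rank)

    # single GPU per process is the supported configuration
    model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

    inp = torch.randn(8, 3, 32, 32).cuda(local_rank)
    model(inp).sum().backward()

if __name__ == "__main__":
    main()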
How does torch 1.1's SyncBN feature compare to the Nvidia apex library?
It's similar in functionality, but likely has better perf because it integrates directly with nn.BatchNorm.