Conversation

@jjsjann123 (Collaborator) commented Nov 21, 2018

  • Summary:

Added synchronized batch normalization, which allows synchronization of statistics across mini-batches between processes within a process group.
The current implementation uses a mixture of extended ATen native functions (C++/CUDA extension) and torch.nn.modules (c10d Python API).

  • User-facing API:
  1. torch.nn.utils.convert_sync_batchnorm(modules, process_group=None)

  2. torch.nn.SyncBatchNorm(num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None)

  • Supported use case:
    DistributedDataParallel with a single GPU per process (multi-process)

a. The user creates a model containing torch.nn.SyncBatchNorm layers in one of the ways listed below:

  1. use the layers directly:

    torch.nn.SyncBatchNorm(...)

    Same API as torch.nn.BatchNormXd(...), with an added argument process_group
    that limits the scope of synchronization to the given process group. The
    default value is None, which implies synchronization across all GPUs.

  2. use torch.nn.utils.convert_sync_batchnorm(modules, process_group)

    Recursively converts all torch.nn.BatchNormXd layers into torch.nn.SyncBatchNorm,
    preserving the values of parameters and buffers. The utility function also lets
    the user apply a process_group value to all converted layers.

b. The user wraps their model with torch.nn.parallel.DistributedDataParallel; from this
point on, the user should follow the general DDP usage guidelines (see the sketch below).
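For concreteness, a minimal sketch of both creation paths followed by the DDP wrap. The backend, rank handling, and model are illustrative placeholders, and path 2 uses the converter name as proposed in this PR (the released API may differ):

    import torch
    import torch.distributed as dist
    import torch.nn as nn

    # assumes a one-GPU-per-process launch (e.g. torch.distributed.launch) that
    # provides the usual environment variables for init_process_group
    dist.init_process_group(backend='nccl')
    local_rank = dist.get_rank()  # assumes a single node, so rank == local GPU index
    torch.cuda.set_device(local_rank)

    # path 1: use the layer directly when building the model
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3),
        nn.SyncBatchNorm(16),  # drop-in replacement for nn.BatchNorm2d
        nn.ReLU(),
    ).cuda()

    # path 2: convert an existing model's BatchNormXd layers instead
    # (converter name as proposed in this PR description)
    # model = torch.nn.utils.convert_sync_batchnorm(model, process_group=None)

    # wrap with DDP, single GPU per process, then train as usual
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])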

  • Error checking

For unsupported use cases, we error out:

  1. Application launched without DDP:

    import torch
    sbn = torch.nn.SyncBatchNorm(10).cuda()
    inp = torch.randn(5, 10, 3, 3).cuda()
    sbn(inp) --> Error!
    AttributeError: SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel

  2. Application launched using DDP with multiple GPUs per process:

    ddp_module = nn.parallel.DistributedDataParallel(module, device_ids=device_ids, output_device=args.local_rank)
    ValueError: SyncBatchNorm is only supported for DDP with single GPU per process

This is a WIP:
1. only supports GPU.
2. only supports single GPU per process.
@jjsjann123 (Collaborator Author)

NCCL on master is broken; this PR requires #14244 to function.

@jjsjann123 (Collaborator Author)

This is the first phase of this PR; we want to have sync BN support in place first.
We can discuss support for multiple GPUs per process, as well as moving the communication into pure C++, later.
One thing we need to discuss in this PR is testing.

My ad-hoc Python tests run multiple processes and communications; I used them for functional testing.
For upstream tests, I'm thinking about reusing the unit tests for the batch norm layer and adding some simple tests in test_distributed.py to validate the communication/parallel Welford part.

As I cannot find an official module upstream doing similar things, feedback or hints would be greatly appreciated.

@jjsjann123 (Collaborator Author)

Pinging @ssnl for visibility.

@weiyangfb (Contributor)

Can I ask, at a very high level, what strategy is used here to implement sync BN?

@jjsjann123 (Collaborator Author)

@weiyangfb If I understand the question correctly, the implementation here is to: 1. calculate the stats (mean/var) for the local mini-batch; 2. all-reduce the stats across all processes to calculate the global mean/var; 3. apply the element-wise normalization.

The backward pass follows the same logic with slightly different arithmetic.
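A rough Python-level sketch of those three steps (the real work happens in the PR's CUDA kernels; this version simply all_gathers per-process mean/var and assumes every process contributes the same number of elements per channel):

    import torch
    import torch.distributed as dist

    def sync_bn_forward_sketch(x, weight, bias, eps=1e-5, process_group=None):
        """Illustrative only; x is assumed to be NCHW."""
        world_size = dist.get_world_size(process_group)

        # 1. per-process statistics over the N, H, W dimensions
        mean = x.mean(dim=[0, 2, 3])
        sqmean = (x * x).mean(dim=[0, 2, 3])
        var = sqmean - mean * mean

        # 2. gather every process's stats and combine them into global mean/var
        mean_l = [torch.empty_like(mean) for _ in range(world_size)]
        var_l = [torch.empty_like(var) for _ in range(world_size)]
        dist.all_gather(mean_l, mean, group=process_group)
        dist.all_gather(var_l, var, group=process_group)
        mean_all = torch.stack(mean_l)   # [world_size, C]
        var_all = torch.stack(var_l)
        global_mean = mean_all.mean(dim=0)
        # E[x^2] = var + mean^2 per process; average it, then subtract global mean^2
        global_var = (var_all + mean_all * mean_all).mean(dim=0) - global_mean * global_mean

        # 3. element-wise normalization with the global statistics
        shape = (1, -1, 1, 1)
        x_hat = (x - global_mean.view(shape)) * torch.rsqrt(global_var.view(shape) + eps)
        return x_hat * weight.view(shape) + bias.view(shape)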

fixing backwards path with process group
fixing linter issue
@jjsjann123 (Collaborator Author)

Many failed tests. Seems like I got a bad commit from master. Will merge master in again later.

@ssnl (Collaborator) left a comment

Thank you! Structure looks good to me. Kernels seem good too, although I didn't try to understand the details (e.g., math). I have some general questions, and this definitely needs some tests.

mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
invstd_l = [invstd_all.narrow(0, i, 1) for i in range(world_size)]
# using all_gather instead of all reduce so we can calculate mean/var in one go
torch.distributed.all_gather(mean_l, mean, process_group)
Collaborator:

Would it be more efficient to coalesce the gathered tensors first and make only one all_gather call? Or is it because they may be of different precisions?
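For reference, the coalescing suggested here could look roughly like the sketch below (using the mean/invstd, world_size, and process_group variables from the quoted snippet; not the PR's code):

    import torch
    import torch.distributed

    # flatten the two stats into one buffer so a single collective is issued
    combined = torch.cat([mean, invstd])                         # shape [2 * C]
    combined_l = [torch.empty_like(combined) for _ in range(world_size)]
    torch.distributed.all_gather(combined_l, combined, process_group)

    # split every gathered buffer back into its mean and invstd halves
    mean_l = [t[:mean.numel()] for t in combined_l]
    invstd_l = [t[mean.numel():] for t in combined_l]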

Collaborator Author:

Fewer NCCL calls, so technically yes. Realistically, the actual communication is a tiny part of the overall timeline. The synchronization might have some impact here, but since we have two consecutive all_gather/all_reduce calls, combining the two doesn't really help.

@fmassa (Member) commented Dec 17, 2018

Quick question: doesn't this PR have some overlap with the work from @apaszke in #15146? Both expose a new batch_norm_update_stats function.

@jjsjann123 (Collaborator Author) commented Dec 17, 2018

Yep. Looks like we are doing similar things here with batch_norm_update_stats (exposing batch_norm_update_stats in at::native).

I need it for synchronization; IIUC, @apaszke did it to fuse the second-step batch norm point-wise kernel with following point-wise ops like ReLU, etc.

We'll need to resolve the conflicts before merging.

@ppwwyyxx (Collaborator) commented Dec 20, 2018

The global mean & variance can be computed with just one all_reduce:

  1. Each worker compute its mean(x) and mean(x^2)
  2. AllReduce them to get global mean(x) and mean(x^2).
  3. Compute the global variance from the global mean(x) and mean(x^2), by var(x) = mean(x^2) - mean(x)^2.

This is the strategy other existing implementations (in mxnet, tensorflow, caffe2, MegDet) is using. And using AllReduce is supposed to be more communication-efficient than using AllGather.
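A compact sketch of that strategy, assuming NCHW input and an equal number of elements per channel on every process (function and variable names are illustrative):

    import torch
    import torch.distributed as dist

    def naive_sync_stats(x, process_group=None):
        """Synchronized per-channel mean/var with a single all_reduce (sketch)."""
        world_size = dist.get_world_size(process_group)

        mean = x.mean(dim=[0, 2, 3])           # local mean(x), shape [C]
        sqmean = (x * x).mean(dim=[0, 2, 3])   # local mean(x^2), shape [C]

        stats = torch.cat([mean, sqmean])      # one buffer -> one collective call
        dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=process_group)
        stats /= world_size

        global_mean, global_sqmean = stats.chunk(2)
        global_var = global_sqmean - global_mean * global_mean  # var = E[x^2] - E[x]^2
        return global_mean, global_var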

@jjsjann123 (Collaborator Author)

I use Welford to calculate the mean/var in a single pass as well. Welford has better numerical characteristics, which is desired.

The all_gather called here only gathers one set of intermediate mean/m2n values per process, so for a reasonable cluster size there shouldn't be much difference between all_gather and all_reduce.
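For context, a sketch of the Welford-style bookkeeping being referred to: each process keeps a per-channel (count, mean, M2) triple, and partial results are merged with the Chan et al. parallel update (names are illustrative; the PR does this inside CUDA kernels):

    import torch

    def local_stats(x):
        """Per-channel (count, mean, M2) for an NCHW tensor. Shown as two passes
        for brevity; the kernel accumulates these in a single Welford pass."""
        n, _, h, w = x.shape
        count = n * h * w
        mean = x.mean(dim=[0, 2, 3])
        m2 = ((x - mean.view(1, -1, 1, 1)) ** 2).sum(dim=[0, 2, 3])
        return count, mean, m2

    def welford_merge(count_a, mean_a, m2_a, count_b, mean_b, m2_b):
        """Combine two partial (count, mean, M2) results."""
        count = count_a + count_b
        delta = mean_b - mean_a
        mean = mean_a + delta * (count_b / count)
        m2 = m2_a + m2_b + delta * delta * (count_a * count_b / count)
        return count, mean, m2

    # after gathering every process's (count, mean, m2), fold them together;
    # the global variance is then m2 / count (and invstd = 1 / sqrt(var + eps))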

@jjsjann123 (Collaborator Author) commented Dec 27, 2018

Timed out in the c10d test. Saw a similar test failure on another PR: #15540.
The test passed on my local machine as well (cuda9_0_cudnn7).

Anything I should be concerned about?

@jjsjann123 (Collaborator Author)

@apaszke I'm repeating Carilli's question regarding the last review comments.
How could I access my parent modules from SyncBatchNorm? i.e., how would I know from within SyncBatchNorm that the model is being launched by DDP or DP?

1. added check so that SyncBatchNorm is only supported for single GPU per
process run with DistributedDataParallel.
2. added utility function to convert BatchNorm layer in module to SyncBatchNorm
layer
@jjsjann123 (Collaborator Author)

@apaszke Added a check during DDP initialization so that SyncBatchNorm is only supported with a single GPU per process when launched with DDP.
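A hedged sketch of that kind of check (the helper name and placement are illustrative, not necessarily the PR's exact code):

    import torch.nn as nn

    def assert_sync_bn_single_device(module, device_ids):
        """Reject SyncBatchNorm when DDP is configured with >1 GPU per process."""
        uses_sync_bn = any(isinstance(m, nn.SyncBatchNorm) for m in module.modules())
        if uses_sync_bn and device_ids is not None and len(device_ids) > 1:
            raise ValueError(
                "SyncBatchNorm is only supported for DDP with single GPU per process")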

@soumith (Member) commented Feb 8, 2019

cc: @mrshenli @teng-li and @pietern for review.
@jjsjann123 can you also check and update what happens if this is accidentally used by a user in an nn.DataParallel setting?

@jjsjann123 (Collaborator Author)

Launching with nn.DataParallel falls under the case where the model is not launched through DDP; at SyncBatchNorm.forward() an exception is raised:

  1. Application launched without ddp:

import torch
sbn = torch.nn.SyncBatchNorm(10).cuda()
inp = torch.randn(5, 10, 3, 3).cuda()
sbn(inp) --> Error!
AttributeError: SyncBatchNorm is only supported within torch.nn.parallel.DistributedDataParallel

@pietern (Contributor) left a comment

There is a potential for deadlock with the current calls to allreduce, as they execute in parallel with autograd, and therefore with the primary DDP reduction hooks.

mean_dy.div_(world_size)
torch.distributed.all_reduce(
    mean_dy_xmu, torch.distributed.ReduceOp.SUM, process_group)
mean_dy_xmu.div_(world_size)
Contributor:

These allreduce calls can interfere with the ones kicked off by DDP itself.

If autograd runs single-threaded with deterministic ordering you'll be fine, but as soon as it doesn't (e.g., there are multiple branches of the forward graph whose backward functions can be called in parallel), you'll run into deadlocks. This can be avoided by creating a new process group from the main one with new_group and using that throughout. Note that having multiple of these sync batch norm layers running backward in parallel can still deadlock, or worse, result in mixed-up data, so for guaranteed correctness you'd have to use a separate process group per sync batch norm layer. This is not ideal and we may need to find a different solution.

cc @mrshenli @teng-li
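A minimal sketch of that workaround, assuming every rank participates and runs this identically (new_group is itself a collective, so all processes must call it, and in the same order):

    import torch.distributed as dist

    # create a dedicated group so the SyncBatchNorm collectives never share a
    # group (and thus an ordering) with DDP's own gradient all_reduce calls
    sync_bn_group = dist.new_group(ranks=list(range(dist.get_world_size())))

    # pass the dedicated group to the layer so it is used for both the forward
    # and backward collectives (via the process_group argument from this PR)
    # sbn = torch.nn.SyncBatchNorm(num_features, process_group=sync_bn_group)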

Collaborator Author:

This is a gotcha for me. Thanks a lot for pointing out the issue. I don't fully understand how the allreduce call would cause trouble with branching while DDP handles it fine; maybe I'll ask for more details about this in a private channel.

Just to reiterate: for the time being, it's a safe workaround as long as I have a separate process group per sync batch norm layer and use it for both the forward and backward passes.
I'll copy/create a process group inside the initializer of SyncBatchNorm.

Collaborator Author:

Realized that duplicating process groups in the initializer of SyncBatchNorm would not work.

new_group has to be called by all processes in the main group, but inside the initializer or converter function each process only sees the process_group it was given for that layer.

Contributor:

I have this exact problem now!
My model has many SyncBNs running in parallel (I also have an atypical module where I call autograd.grad() repeatedly on two tensors whose graph contains SyncBNs). In short, backward() issues many all_reduce calls, and some run in parallel. With DDP the code hangs without any error, and from the NCCL logs I can see that it is because an all_reduce hangs.

I have tried torch native SyncBN, apex SyncBN, and @ppwwyyxx's NaiveSyncBN. This happens with both torch DDP and apex DDP (even with delay_allreduce=True).

Is there something I can do to fix it? @jjsjann123 @pietern

1. adding async_op for all_reduce calls
2. renaming variables
3. removing redundant code
@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang (Contributor) commented Feb 28, 2019

We've decided we're going to go ahead and land this, and keep an eye on it for any problems that may occur later.

@ezyang (Contributor) commented Mar 1, 2019

@pytorchbot rebase this please

@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@jjsjann123 (Collaborator Author)

Failures look scary; should I be concerned?

@ezyang (Contributor) commented Mar 1, 2019

Ah, the PR bitrotted. Just a moment please.

ezyang added 2 commits March 1, 2019 17:03
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@facebook-github-bot (Contributor) left a comment

@ezyang is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Mar 6, 2019
Pull Request resolved: pytorch/pytorch#14267

Differential Revision: D14270035

Pulled By: ezyang

fbshipit-source-id: 4956d8fa565c32e9df5408d53719ff9f945f4d6d
@DrJimFan commented May 1, 2019

How does torch 1.1's SyncBN feature compare to the Nvidia apex library?

@soumith (Member) commented May 1, 2019

It's similar in functionality, but likely has better performance because it integrates directly with nn.BatchNorm.
