
Batch Normalization for Multi-GPU / Data Parallelism #7439

Closed
kiranvaidhya opened this issue Feb 11, 2017 · 35 comments

Labels
stat:contribution welcome Status - Contributions welcome

Comments

@kiranvaidhya commented Feb 11, 2017

Where is the batch normalization implementation for Multi-GPU scenarios? How does one keep track of mean, variance, offset and scale in the context of the Multi-GPU example as given in the CIFAR-10 tutorial?

Why is the question on StackOverflow left unanswered for so long?

For all the niceties it brings with TensorBoard etc., it's somewhat appalling to see TensorFlow so far behind Torch in terms of modeling capability. I'd be really glad if someone took responsibility and came up with a decent batch normalization implementation for all cases. Even if it already exists, could someone care enough to write good documentation for it?

There are so many issues pertaining to batch normalization with TensorFlow. It's important that you straighten this out: batch normalization enables very fast convergence for very deep networks, and it is REALLY important for modern-day deep learning research.

PS: Please forgive the outburst. I've been a Torch user for more than a year and had very high hopes for TensorFlow.

@yaroslavvb (Contributor) commented Feb 11, 2017

How does Torch handle multi-GPU batch normalization? Batch normalization over a multi-GPU batch incurs an extra performance penalty because statistics need to be communicated across all GPUs, so there are some performance questions to consider. You can aggregate the statistics on the CPU, aggregate them by going around in a ring along the lines of NVIDIA's NCCL all-reduce, or aggregate them with a tree reduction.

You can also do a "pseudo-batch normalization": use the existing batch norm layer to normalize GPU-sized batches, and then combine those batches into a single "multi-GPU batch".

I suspect there are easier ways to handle normalization of huge batches that don't introduce the performance hit you would see with batch normalization, like weight normalization -- https://arxiv.org/pdf/1602.07868.pdf
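
For reference, a minimal sketch of the "aggregate statistics on the CPU" option (assuming NHWC activations and equal-sized per-GPU shards); tower_inputs is a hypothetical list of per-GPU tensors, not an API from any framework discussed here:

import tensorflow as tf

def cross_gpu_moments(tower_inputs, axes=(0, 1, 2)):
    """Average per-GPU moments on the CPU (assumes equally sized shards per GPU)."""
    means, mean_sqs = [], []
    for i, x in enumerate(tower_inputs):
        with tf.device('/gpu:%d' % i):
            means.append(tf.reduce_mean(x, axis=axes))
            mean_sqs.append(tf.reduce_mean(tf.square(x), axis=axes))
    with tf.device('/cpu:0'):
        mean = tf.add_n(means) / len(means)
        mean_sq = tf.add_n(mean_sqs) / len(mean_sqs)
        variance = mean_sq - tf.square(mean)  # Var[x] = E[x^2] - (E[x])^2
    return mean, variance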

@ppwwyyxx (Contributor)

@kvrd18 As pointed out above, there are just too many ways to implement batch norm across GPUs. TensorFlow currently doesn't seem to provide a "default" way of doing it.

@yaroslavvb My understanding is that most frameworks (including Caffe & Torch) don't aggregate statistics across GPUs at all. Different GPUs maintain statistics independently, and statistics from only one GPU are used at test time. The official Inception v3 example in tensorflow/models also does something similar.
I've used the same strategy for quite a while and it works fine. One catch is that in this case it's more important to shuffle the training data; otherwise the statistics on different GPUs are no longer i.i.d.

@shiyemin

I have the same issue. I believe distributed batch normalization is very important for some problems like action recognition. It would be very useful if TensorFlow provided an implementation.

Here is my trick for training BN on action recognition: I rearrange the sample order so that every GPU gets a reasonable mean and variance.

@yaroslavvb (Contributor)

@shiyemin Are the approaches mentioned in @ppwwyyxx's reply acceptable for your needs? This may be good material for a community contribution. If Google adds multi-GPU batch norm, it may end up optimized for their custom interconnect hardware rather than for commodity hardware.

@shiyemin

@ppwwyyxx @yaroslavvb As far as I know, shuffling the training data is not enough for tasks like action recognition, because one GPU can only handle 2 videos (each with 16 frames), which is too small a batch for calculating the mean and variance. That's why I shuffle the batch itself at each step; after going through the CNN, the batch is rearranged back to its original order before being fed into the LSTM (see the sketch below).

However, my trick hurts speed severely.
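
Roughly, the trick looks like the following sketch, a simplification that shuffles the full batch before it is split across towers; cnn here is a hypothetical per-frame network, not code from my project:

import tensorflow as tf

def shuffled_cnn(frames, cnn):
    # Shuffle the batch so each tower's shard mixes samples from many videos.
    batch = tf.shape(frames)[0]
    perm = tf.random_shuffle(tf.range(batch))
    inv_perm = tf.invert_permutation(perm)      # permutation that undoes the shuffle
    shuffled = tf.gather(frames, perm)
    features = cnn(shuffled)                    # batch norm runs inside the CNN
    # Restore the original sample order before feeding the LSTM.
    return tf.gather(features, inv_perm)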

@kiranvaidhya (Author) commented Feb 13, 2017

@yaroslavvb, in Torch, the weight updates for each module in a replica are accumulated and summed together on the first replica. Owing to Torch's modular code base, with BatchNormalization being a module, its internal parameters go through the same computational flow described above. Torch also uses NCCL to enable fast inter-GPU communication.

To define a model in Torch, you would do this,

function makeConvNet()
  local model = nn.Sequential()
  model:add(nn.SpatialConvolution(1, 32, 3, 3))   -- 1 input plane, 32 output planes, 3x3 kernel
  model:add(nn.SpatialBatchNormalization(32))     -- normalize the 32 feature maps
  model:add(nn.View(-1):setNumInputDims(3))       -- flatten for the classifier
  return model
end

Here, nn.SpatialConvolution and nn.SpatialBatchNormalization are modules which have their own forward and backward passes. All you have to do to make the model compatible with data parallelism is to wrap it in nn.DataParallelTable:

-- CONSTRUCT MODEL:
conv_net = makeConvNet()  -- i.e. create nn.Sequential() and fill it
net = nn.DataParallelTable(1)  -- Split along first (batch) dimension
net:add(conv_net, {1, 2}) -- Use GPUs 1 and 2
-- TRAINING:
for i = 1, num_epochs do
  local output = net:forward(input)
  local err = criterion:forward(output, target)
  net:zeroGradParameters()
  local gradOutput = criterion:backward(output, target)
  local gradInput = net:backward(input, gradOutput)
  net:updateParameters(lr)
end

@ppwwyyxx (Contributor)

@kvrd18 What you described is the general case for most modules in Torch. For batch normalization, however, my best understanding is that Torch by default doesn't synchronize the mean/variance among GPUs, only the other two parameters (scale and shift).
Relevant issue here: torch/nn#1071

@kiranvaidhya (Author)

It's going to be painful to train fully convolutional networks on multiple GPUs: they cannot afford the huge batch sizes that would alleviate the problems arising from not synchronizing the mean and variance across GPUs.

@yaroslavvb (Contributor)

It would be a good experiment to run -- compare the Torch approach vs. keeping the variance on GPU 0 vs. keeping it on the CPU. I suspect that when your GPUs are P2P-connected, keeping the variables on GPU 0 will be better (I found the CIFAR multi-GPU example runs 15% faster when the weights are pinned to GPU 0).

@ppwwyyxx (Contributor)

@kvrd18 It could be an improvement to aggregate the statistics (before the actual normalization) instead of normalizing by each GPU's own statistics. This avoids the potential problem that the statistics of a small batch are too unstable. You can do this in TensorFlow, but it is going to be very expensive.

Maybe Batch Renormalization is a better option in this case; it shows better performance on small batches.
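
For anyone who wants to try it, Batch Renormalization is exposed in TensorFlow through the renorm arguments of tf.layers.batch_normalization; a minimal sketch (the clipping values below are illustrative, not a recommendation):

import tensorflow as tf

def batch_renorm(x, is_training):
    # renorm=True enables Batch Renormalization on top of the usual batch norm.
    return tf.layers.batch_normalization(
        x,
        training=is_training,
        renorm=True,
        renorm_clipping={'rmax': 3.0, 'rmin': 1.0 / 3.0, 'dmax': 5.0},
        renorm_momentum=0.99)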

@aselle (Contributor) commented Feb 14, 2017

This question might be better asked on StackOverflow, since it is not a clear feature request yet. However, if we define this as a feature request for simple, easy-to-use multi-GPU batch normalization, that would be a great contribution. Marking this as contributions welcome as a result. Thanks!

@aselle aselle added the stat:contribution welcome Status - Contributions welcome label Feb 15, 2017
@kiranvaidhya (Author)

I've built a batch normalization layer for multi-GPU training. It predicts well on the validation set only if is_training is True; I'm not sure why. Can someone help me with this?

import tensorflow as tf

def _variable_on_cpu(name, shape, initializer, trainable=True):
    # Create (or reuse) a variable pinned to the CPU so all GPU towers share it.
    with tf.device('/cpu:0'):
        dtype = tf.float32
        var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype, trainable=trainable)
    return var

def BatchNorm(inputs, is_training, decay=0.9, epsilon=1e-3):
    # Learnable scale/offset plus non-trainable population statistics, all kept on the CPU.
    scale = _variable_on_cpu('scale', inputs.get_shape()[-1], tf.constant_initializer(1.0))
    beta = _variable_on_cpu('beta', inputs.get_shape()[-1], tf.constant_initializer(0.0))
    pop_mean = _variable_on_cpu('mean', inputs.get_shape()[-1], tf.constant_initializer(0.0), trainable=False)
    pop_var = _variable_on_cpu('variance', inputs.get_shape()[-1], tf.constant_initializer(1.0), trainable=False)
    axis = list(range(len(inputs.get_shape()) - 1))

    def Train(inputs, pop_mean, pop_var, scale, beta):
        # Normalize with the current batch statistics and update the moving averages.
        batch_mean, batch_var = tf.nn.moments(inputs, axis)
        train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
        train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
        with tf.control_dependencies([train_mean, train_var]):
            return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)

    def Eval(inputs, pop_mean, pop_var, scale, beta):
        # Normalize with the accumulated population statistics.
        return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)

    return tf.cond(is_training, lambda: Train(inputs, pop_mean, pop_var, scale, beta),
                   lambda: Eval(inputs, pop_mean, pop_var, scale, beta))

This is working well on multi-GPU / data parallelism as long as the module is in training mode.

@shiyemin

@kvrd18 For multi-GPU batch norm, we have to sync the batch_mean and batch_var, not just the moving_mean and moving_var, so that every GPU gets a batch_mean and batch_var close to the global mean and variance.
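
To illustrate what syncing the batch statistics could look like, here is a rough sketch (not code from any project linked in this thread) that all-reduces per-tower moments with tf.contrib.nccl; tower_means and tower_mean_sqs are hypothetical lists holding one per-GPU tensor each, e.g. tf.reduce_mean(x, axis=[0, 1, 2]) and tf.reduce_mean(tf.square(x), axis=[0, 1, 2]) for NHWC activations:

import tensorflow as tf
from tensorflow.contrib import nccl

def synced_moments(tower_means, tower_mean_sqs, num_gpus):
    # Each GPU receives the sum of every tower's moments, so all towers
    # normalize with the same (global) mean and variance.
    summed_means = nccl.all_sum(tower_means)
    summed_sqs = nccl.all_sum(tower_mean_sqs)
    means = [m / num_gpus for m in summed_means]
    variances = [s / num_gpus - tf.square(m)   # Var[x] = E[x^2] - (E[x])^2
                 for s, m in zip(summed_sqs, means)]
    return means, variances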

@kiranvaidhya (Author)

@shiyemin If I'm correct, each layer on each GPU will have its own batch_mean and batch_var, because the input data is split along the batch dimension and fed to each GPU. I don't understand what you mean by syncing the batch_mean and batch_var; they're specific to the input batch and are computed by tf.nn.moments.

@kiranvaidhya (Author)

@shiyemin If I understood you right, I'd have to compute the moments over the concatenated input batches on the fly while the forward pass is being computed. That is going to be computationally very expensive. Is this where the use of NVIDIA's NCCL is recommended?

@kiranvaidhya (Author)

The Inception example pointed to in #7439 (comment) is a good enough solution for me for now. Thanks, @ppwwyyxx.

@John1231983

Hello, I am training BatchNorm layers on multiple GPUs using the tf.contrib.layers.batch_norm function. In the training phase, we have to collect the moving_mean and moving_variance updates via

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

However, I have found several ways this is used:

1. Inside the tower loop (cifar10_main)

with tf.device('/cpu:0'):
  update_ops = []
  with tf.variable_scope(tf.get_variable_scope()):
    for i in range(self.conf.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('device_%d' % i):
          # ... build the tower, then collect its BN update ops
          update_ops.extend(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
  variable_averages = tf.train.ExponentialMovingAverage(self.conf.MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())
  with tf.control_dependencies(update_ops):
    self.train_op = tf.group(train_op_conv, variables_averages_op)

2. Outside the tower loop (cifar10_multi_gpu)

with tf.device('/cpu:0'):
  with tf.variable_scope(tf.get_variable_scope()):
    for i in range(self.conf.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('device_%d' % i):
          # ... build the tower; ignore update_ops here
          pass
  variable_averages = tf.train.ExponentialMovingAverage(self.conf.MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    self.train_op = tf.group(train_op_conv, variables_averages_op)

3. Both inside and outside the tower loop (Inception v3, cifar10)

with tf.device('/cpu:0'):
  with tf.variable_scope(tf.get_variable_scope()):
    for i in range(self.conf.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('device_%d' % i):
          # ... build the tower, then re-collect the update ops
          update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  variable_averages = tf.train.ExponentialMovingAverage(self.conf.MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())
  batchnorm_updates_op = tf.group(*update_ops)
  self.train_op = tf.group(train_op_conv, train_op_fc, variables_averages_op, batchnorm_updates_op)

Which is the right way? In my opinion, it may be the third one.
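
For what it's worth, one pattern used elsewhere (a sketch under the assumption that the name scopes match the snippets above, not an authoritative answer) is to collect the update ops from only the first tower's scope, so each moving average is updated once per step:

import tensorflow as tf

# Collect only the BN update ops created inside the first tower's name scope;
# train_op_conv and variables_averages_op are the ops from the snippets above.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='device_0')
with tf.control_dependencies(update_ops):
    train_op = tf.group(train_op_conv, variables_averages_op)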

@holyseven

I've found a simple way to implement distributed batch normalization in pure TensorFlow, which I would like to share with you: batch norm across GPUs.

This may be interesting for video/action recognition, image segmentation and other domains where the batch size is very limited in a single GPU.

@twangnh commented Jun 29, 2018

@ppwwyyxx (Contributor)

@MrWanter The code you linked to has nothing to do with batchnorm.

@ppwwyyxx (Contributor)

For batch norm with multi-GPU statistics, the link given by @holyseven looks like a correct implementation. Tensorpack also has such a feature, which works with its multi-GPU trainers.
For distributed batch norm, tensorpack has an implementation based on its HorovodTrainer, but it depends on horovod/horovod#331.

@twangnh commented Jun 30, 2018

@holyseven's implementation seems to split the batch and construct each layer on each GPU one by one. In distributed training with multiple workers, wouldn't that be hard to integrate with existing functions like tf.train.replica_device_setter? Also, splitting the whole batch across workers with this implementation seems infeasible.

@ppwwyyxx (Contributor)

> do you know the purpose of variables_to_average=variables_to_average in tf.train.SyncReplicasOptimizer?

It averages the trainable parameters and the moving averages among workers. It does not perform cross-GPU or cross-machine batch norm.

> is it possible to modify it for distributed batchnorm?

It's very far from that.

> can you refer to some examples of distributed batchnorm with Horovod?

With working horovod training code, adding the option sync_statistics='horovod' to BatchNorm is all you need to do. But for now its performance has a lot of room for improvement (horovod/horovod#318 (comment)).
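
For context, the usage is roughly the sketch below (argument placement may differ between tensorpack versions, so treat it as an illustration rather than exact library code):

from tensorpack.models import BatchNorm

def tower_fn(x):
    # 'nccl' syncs statistics across the GPUs of one machine;
    # 'horovod' syncs across all Horovod ranks.
    return BatchNorm('bn', x, sync_statistics='horovod')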

@twangnh commented Jun 30, 2018

Thanks for the response. How about using your implementation with tf.contrib.nccl? As in the thread, you mention your CGBN based on tf.contrib.nccl reaches 1380 im/s, and without CGBN it is 1556 im/s, so the overhead is small.
So you mean the room for improvement is the remaining ~170 im/s?

@ppwwyyxx (Contributor)

I meant there might be room for Horovod, not much for NCCL.

@twangnh commented Jun 30, 2018

I see. Have you tried using NCCL for synchronous distributed batch norm? The communication would be heavier, since batch statistics need to be aggregated per layer across workers over the network connection.

@ppwwyyxx (Contributor)

NCCL does not support it.

@twangnh commented Jun 30, 2018

It seems NCCL2 supports inter-node all_reduce operations.

@ppwwyyxx (Contributor)

It seems there is no way to use that from TensorFlow.

@michaelklachko

What is the difference between @holyseven's implementation and the inception example?

@jkyl (Contributor) commented Jun 8, 2019

Here is a custom Keras layer which implements train-phase cross-replica batch normalization under a MirroredStrategy. If anyone finds this useful or wants to submit a PR, I could use some help implementing the prediction phase (i.e. moving mean/variance).

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from tensorflow.python.keras.engine.base_layer import InputSpec, Layer
import tensorflow as tf

class SyncBatchNorm(Layer):
  """Cross-replica batch normalization layer"""
  def __init__(
      self,
      center=True,
      scale=False,
      trainable=True,
      name=None,
      **kwargs
    ):
    super(SyncBatchNorm, self).__init__(
      name=name, trainable=trainable, **kwargs)
    self.axis = -1
    self.center = center
    self.scale = scale
    self.supports_masking = True
    self.epsilon = 1e-3

  def build(self, input_shape):
    dim = input_shape[self.axis]
    if dim is None:
      raise ValueError(
        'Axis ' + str(self.axis) + ' of '
        'input tensor should have a defined dimension '
        'but the layer received an input with shape ' +
        str(input_shape) + '.'
      )
    self.input_spec = InputSpec(
      ndim=len(input_shape),
      axes={self.axis: dim}
    )
    shape = (dim,)
    if self.scale:
      self.gamma = self.add_weight(
        shape=shape,
        name='gamma',
        initializer='ones',
      )
    else:
      self.gamma = None
    if self.center:
      self.beta = self.add_weight(
        shape=shape,
        name='beta',
        initializer='zeros',
      )
    else:
      self.beta = None
    self.built = True

  def call(self, x, training=None):
    # Average each replica's mean and mean-of-squares across all replicas,
    # then recover the global variance as E[x^2] - E[x]^2.
    ctx = tf.distribute.get_replica_context()
    n = ctx.num_replicas_in_sync
    mean, mean_sq = ctx.all_reduce(
      tf.distribute.ReduceOp.SUM,
      [tf.reduce_mean(x, axis=0) / n,
       tf.reduce_mean(x**2, axis=0) / n]
    )
    variance = mean_sq - mean ** 2
    return tf.nn.batch_normalization(
      x,
      mean,
      variance,
      self.beta,
      self.gamma,
      self.epsilon)

  def compute_output_shape(self, input_shape):
    return input_shape

  def get_config(self):
    return {
      'axis': self.axis,
      'epsilon': self.epsilon,
      'center': self.center,
      'scale': self.scale,
    }
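
A possible usage sketch for the layer above (the model architecture below is purely illustrative): build the model inside the strategy scope so every replica participates in the all-reduce inside call().

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, input_shape=(32, 32, 3)),
        SyncBatchNorm(center=True, scale=True),  # the layer defined above
        tf.keras.layers.ReLU(),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10),
    ])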

@jiangxuetao0823

> (quoting @John1231983's earlier comment above about the three ways of collecting tf.GraphKeys.UPDATE_OPS)

Have you found the right way? Thanks.

@John1231983

I never used TensorFlow again because it is hard to do research with. I gave it up without finding the answer; I've since moved to PyTorch.

@mrgloom commented Nov 20, 2019

Does running under Horovod average batch norm statistics across GPUs?

@byronyi (Contributor) commented Feb 3, 2020

> Does running under Horovod average batch norm statistics across GPUs?

The answer is no AFAIK, at least as of Horovod 0.19.0.

FYI, TF just added sync BN: adf7690
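
For anyone landing here later, this is exposed in recent TF 2.x releases as tf.keras.layers.experimental.SyncBatchNormalization; a minimal sketch (the model below is illustrative):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64),
        # Aggregates batch statistics across all replicas at train time.
        tf.keras.layers.experimental.SyncBatchNormalization(),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dense(10),
    ])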
