rewrite allreduce and avoid bug in TF's nccl
ppwwyyxx committed Jul 20, 2020
1 parent dbc0b36 commit 6151e04
Showing 7 changed files with 155 additions and 76 deletions.
2 changes: 1 addition & 1 deletion examples/FasterRCNN/train.py
@@ -116,5 +116,5 @@
trainer = HorovodTrainer(average=False)
else:
# nccl mode appears faster than cpu mode
trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, average=False, mode='nccl')
trainer = SyncMultiGPUTrainerReplicated(cfg.TRAIN.NUM_GPUS, average=False)
launch_train_with_config(traincfg, trainer)
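
The trainer now leaves the allreduce implementation to the builder's default instead of pinning mode='nccl'. A specific backend can still be requested explicitly; the sketch below is hypothetical usage, assuming SyncMultiGPUTrainerReplicated keeps forwarding mode to the builder, which in this commit accepts 'nccl', 'cpu', 'hierarchical', 'gpu' and 'collective'.

# Hypothetical: request a specific allreduce backend instead of the default.
# Assumes the trainer still forwards `mode` to SyncMultiGPUReplicatedBuilder.
trainer = SyncMultiGPUTrainerReplicated(
    cfg.TRAIN.NUM_GPUS, average=False, mode='collective')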
11 changes: 7 additions & 4 deletions tensorpack/graph_builder/distributed.py
@@ -8,7 +8,7 @@
from ..utils import logger
from ..utils.argtools import memoized
from .training import DataParallelBuilder, GraphBuilder
from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable
from .utils import OverrideCachingDevice, split_grad_list, allreduce_grads_naive, override_to_local_variable

__all__ = []

@@ -123,7 +123,9 @@ def build(self, get_grad_fn, get_opt_fn):
DataParallelBuilder._check_grad_list(grad_list)

with tf.device(self.param_server_device):
grads = aggregate_grads(grad_list, colocation=False)
all_grads, all_vars = split_grad_list(grad_list)
all_grads = allreduce_grads_naive(all_grads)
grads = [(g, v) for g, v in zip(all_grads, all_vars[0])]
opt = get_opt_fn()
train_op = opt.apply_gradients(grads, name='train_op')
train_op = self._add_sync_queues_and_barrier('all_workers_sync_barrier', [train_op])
@@ -285,8 +287,9 @@ def build(self, get_grad_fn, get_opt_fn):
use_vs=[True] * len(self.towers)) # open vs at each tower
DataParallelBuilder._check_grad_list(grad_list)

avg_grads = aggregate_grads(
grad_list, colocation=False, devices=self.raw_devices)
all_grads, all_vars = split_grad_list(grad_list)
avg_grads = allreduce_grads_naive(all_grads, devices=self.raw_devices) # N
avg_grads = [(g, v) for g, v in zip(avg_grads, all_vars[0])]
with tf.device(self.param_server_device):
ps_var_grads = DistributedReplicatedBuilder._apply_shadow_vars(avg_grads)
var_update_ops = self._apply_gradients_and_copy(
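
Both hunks above switch to the convention used throughout this commit: split the (grad, var) pairs, allreduce only the gradient tensors, then re-attach the variables of the first tower. A minimal sketch of that transformation with the names from the diff (shapes in comments):

# grad_list is K x N x 2: K towers, each a list of N (grad, var) pairs.
all_grads, all_vars = split_grad_list(grad_list)        # both K x N
reduced = allreduce_grads_naive(all_grads)               # N reduced gradients
# Pair each reduced gradient with the matching variable of tower 0.
grads = [(g, v) for g, v in zip(reduced, all_vars[0])]
train_op = get_opt_fn().apply_gradients(grads, name='train_op')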
66 changes: 33 additions & 33 deletions tensorpack/graph_builder/training.py
@@ -16,7 +16,9 @@
from ..utils import logger
from ..utils.develop import HIDE_DOC
from .utils import (
GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
GradientPacker, LeastLoadedDeviceSetter,
aggregate_grads_colocate, allreduce_grads_naive,
allreduce_grads, allreduce_grads_hierarchical,
merge_grad_list, override_to_local_variable, split_grad_list)

__all__ = ["DataParallelBuilder"]
@@ -173,12 +175,13 @@ def build(self, grad_list, get_opt_fn):
assert len(grad_list) == len(self.towers)
DataParallelBuilder._check_grad_list(grad_list)

# debug tower performance (without update):
# debug tower performance:
# ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
# self.train_op = tf.group(*ops)
# return

self.grads = aggregate_grads(grad_list, colocation=True)
self.grads = aggregate_grads_colocate(grad_list)
# debug tower performance:
# grads = grad_list[0]

opt = get_opt_fn()
@@ -204,13 +207,11 @@ class SyncMultiGPUReplicatedBuilder(DataParallelBuilder):
def __init__(self, towers, average, mode):
super(SyncMultiGPUReplicatedBuilder, self).__init__(towers)
self._average = average
assert mode in ['nccl', 'cpu', 'hierarchical'], mode
if get_tf_version_tuple() >= (2, 0) and mode == 'cpu':
mode = 'nccl' # cpu mode causes the entire model to get located on cpu
assert mode in ['nccl', 'cpu', 'hierarchical', 'gpu', 'collective'], mode
self._mode = mode

if self._mode == 'hierarchical' and len(towers) != 8:
logger.warn("mode='hierarchical' require >= 8 GPUs. Fallback to mode='nccl'.")
logger.warn("mode='hierarchical' require 8 GPUs. Fallback to mode='nccl'.")
self._mode = 'nccl'

def call_for_each_tower(self, tower_fn):
@@ -257,39 +258,38 @@ def build(self, grad_list, get_opt_fn):
valid_for_nccl = all(k in dtypes_nccl_supported for k in dtypes)
if self._mode == 'nccl' and not valid_for_nccl:
logger.warn("Cannot use mode='nccl' because some gradients have unsupported types. Fallback to mode='cpu'")
self._mode = 'cpu'
self._mode = 'gpu'

if self._mode in ['nccl', 'hierarchical']:
all_grads, all_vars = split_grad_list(grad_list)
all_grads, all_vars = split_grad_list(grad_list)

def do_allreduce(all_grads):
# use allreduce from tf-benchmarks
# from .batch_allreduce import AllReduceSpecAlgorithm
# algo = AllReduceSpecAlgorithm('nccl', list(range(8)), 0, 10)
# all_grads, warmup_ops = algo.batch_all_reduce(all_grads, 1, True, False)
# print("WARMUP OPS", warmup_ops)

if self._mode == 'nccl':
all_grads = allreduce_grads(all_grads, average=self._average) # #gpu x #param
if self._mode in ['nccl', 'collective']:
# #gpu x #param
all_grads = allreduce_grads(all_grads, average=self._average, mode=self._mode)
elif self._mode == 'hierarchical':
all_grads = allreduce_grads_hierarchical(all_grads, raw_devices, average=self._average)
else:
packer = GradientPacker(len(raw_devices))
succ = packer.compute_strategy(all_grads[0])
if succ:
packed_grads = packer.pack_all(all_grads, raw_devices)
packed_grads_aggr = allreduce_grads_hierarchical(
packed_grads, raw_devices, average=self._average)
all_grads = packer.unpack_all(packed_grads_aggr, raw_devices)
else:
all_grads = allreduce_grads_hierarchical(all_grads, raw_devices, average=self._average)

self.grads = merge_grad_list(all_grads, all_vars)
elif self._mode == 'cpu':
agg_grad_and_vars = aggregate_grads(
grad_list, colocation=False,
devices=['/cpu:0'], average=self._average) # #param x 2
self.grads = [] # #gpu x #param x 2
for grad_and_vars in grad_list: # grad_and_vars: #paramx2
# take v from each tower, and g from average.
self.grads.append(
[(g, v) for (_, v), (g, _) in zip(grad_and_vars, agg_grad_and_vars)])
devices = ['/cpu:0'] if self._mode == 'cpu' else raw_devices
all_grads = allreduce_grads_naive(all_grads, devices=devices, average=self._average)
all_grads = [all_grads] * len(self.towers)
return all_grads

use_packer = self._mode in ['hierarchical']
if use_packer:
packer = GradientPacker(len(raw_devices))
use_packer = packer.compute_strategy(all_grads[0]) # may fail to pack
if use_packer:
all_grads = packer.pack_all(all_grads, raw_devices)
all_grads = do_allreduce(all_grads) # all the work happens here
if use_packer:
all_grads = packer.unpack_all(all_grads, raw_devices)

self.grads = merge_grad_list(all_grads, all_vars)

train_ops = []
opt = get_opt_fn()
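
For reference, the K x N layout that the rewritten build() passes around can be illustrated with plain Python objects; split_grad_list and merge_grad_list are pure list manipulations, so the toy example below (strings standing in for tensors) should run as long as tensorpack is importable.

# Toy illustration of the K x N convention (K=2 towers, N=2 variables).
from tensorpack.graph_builder.utils import split_grad_list, merge_grad_list

grad_list = [
    [('g0@gpu0', 'w0'), ('g1@gpu0', 'w1')],   # tower 0: N (grad, var) pairs
    [('g0@gpu1', 'w0'), ('g1@gpu1', 'w1')],   # tower 1
]
all_grads, all_vars = split_grad_list(grad_list)
assert all_grads == [['g0@gpu0', 'g1@gpu0'], ['g0@gpu1', 'g1@gpu1']]  # K x N
assert all_vars == [['w0', 'w1'], ['w0', 'w1']]                       # K x N
# Allreduce keeps the K x N layout; merge_grad_list re-pairs grads and vars.
assert merge_grad_list(all_grads, all_vars) == grad_list              # K x N x 2 again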
126 changes: 91 additions & 35 deletions tensorpack/graph_builder/utils.py
@@ -5,6 +5,7 @@
import operator
from contextlib import contextmanager
import tensorflow as tf
import threading

from ..compat import tfv1
from ..tfutils.common import get_tf_version_tuple
@@ -13,7 +14,7 @@
from ..utils import logger
from ..utils.argtools import call_only_once

__all__ = ["LeastLoadedDeviceSetter", "allreduce_grads", "aggregate_grads"]
__all__ = ["LeastLoadedDeviceSetter", "allreduce_grads"]


"""
@@ -33,6 +34,19 @@ def _replace_global_by_local(kwargs):
kwargs['collections'] = list(collections)


_module_lock = threading.Lock()
_shared_cnt_counter = 0


def _get_shared_cnt():
global _shared_cnt_counter

with _module_lock:
val = _shared_cnt_counter
_shared_cnt_counter += 1
return val


@contextmanager
def override_to_local_variable(enable=True):
"""
@@ -84,17 +98,18 @@ def canonicalize(name): # tensorflow/tensorflow#11484
if op.type not in ['Variable', 'VariableV2']:
return canonicalize(self.worker_device)

device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1))
device_name = self.place_with_balance(op)
return canonicalize(device_name)

def place_with_balance(self, op):
device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
device_name = self.ps_devices[device_index]
var_size = op.outputs[0].get_shape().num_elements()
if var_size is None:
logger.warn("[LeastLoadedDeviceSetter] Shape of variable {} is not fully defined!".format(op.name))
var_size = 0

self.ps_sizes[device_index] += var_size

return canonicalize(device_name)
return device_name

def __str__(self):
return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
@@ -130,28 +145,42 @@ def merge_grad_list(all_grads, all_vars):


@under_name_scope('AllReduceGrads')
def allreduce_grads(all_grads, average):
def allreduce_grads(all_grads, average, mode="nccl"):
"""
All-reduce average the gradients among K devices. Results are broadcasted to all devices.
Args:
all_grads (K x N): List of list of gradients. N is the number of variables.
average (bool): average gradients or not.
mode (str): "nccl", "collective"
Returns:
K x N: same as input, but each grad is replaced by the average over K devices.
"""
assert mode in ["nccl", "collective"], mode

if get_tf_version_tuple() <= (1, 12):
from tensorflow.contrib import nccl # deprecated
else:
from tensorflow.python.ops import nccl_ops as nccl
nr_tower = len(all_grads)
if nr_tower == 1:
return all_grads
new_all_grads = [] # N x K
for grads in zip(*all_grads):
summed = nccl.all_sum(grads)
# k grads
if mode == "nccl":
if get_tf_version_tuple() <= (1, 12):
from tensorflow.contrib import nccl # deprecated
else:
from tensorflow.python.ops import nccl_ops as nccl
summed = nccl.all_sum(grads)
else:
from tensorflow.python.ops import collective_ops
summed = []
shared_cnt = _get_shared_cnt()
for t in grads:
with tf.device(t.device):
t = collective_ops.all_reduce(
t, len(grads), shared_cnt, shared_cnt + 100,
'Add', 'Id')
summed.append(t)

grads_for_devices = [] # K
for g in summed:
@@ -229,28 +258,57 @@ def allreduce_grads_hierarchical(all_grads, devices, average=False):
return agg_all_grads


@under_name_scope('AggregateGrads')
def aggregate_grads(all_grads,
colocation=False,
devices=None,
average=True):
@under_name_scope('AggregateGradsColocate')
def aggregate_grads_colocate(all_grads, average=True):
"""
Average the gradients.
Aggregate the gradients. The aggregation is colocated with the variable.
Args:
all_grads (K x N x 2): A list of K lists. Each of the list is a list of N (grad, var) tuples.
The variables have to be shared across the K lists.
average (bool): do average or sum
Returns:
(N x 2): A list of N (grad, var) tuples, where grad is averaged or summed over K.
"""
nr_tower = len(all_grads)
if nr_tower == 1:
return all_grads[0]

def aggregate(grads):
if average:
return tf.multiply(tf.add_n(grads), 1.0 / nr_tower)
else:
return tf.add_n(grads)

ret = []
for idx, grad_and_vars in enumerate(zip(*all_grads)):
# Ngpu * 2
v = grad_and_vars[0][1]
grads = [g for (g, _) in grad_and_vars]
with tf.device(v.device): # colocate summed grad with var
grad = aggregate(grads)
ret.append((grad, v))
return ret


@under_name_scope('AllReduceNaive')
def allreduce_grads_naive(all_grads, devices=None, average=True):
"""
AllReduce the gradients with raw ops (instead of collective ops).
Args:
all_grads (K x N): A list of K lists. Each of the list is a list of N grad tuples.
The variables have to be the same across the K lists.
colocation (bool): colocate gradient averaging on the device of the variable.
devices (list[str]): assign the averaging to these device in
round-robin. Cannot be used together with ``colocation``.
average (bool): do average or sum
Returns:
(N x 2): A list of N (grad, var) tuples, where grad is averaged or summed over K.
list[Tensor]: list of grads where each grad is averaged or summed over K.
"""
assert not (devices is not None and colocation)
if devices is not None:
assert isinstance(devices, list), devices
# device_setter = LeastLoadedDeviceSetter(None, devices)

nr_tower = len(all_grads)
if nr_tower == 1:
@@ -262,26 +320,22 @@ def aggregate(grads):
else:
return tf.add_n(grads)

ret = []
for idx, grad_and_vars in enumerate(zip(*all_grads)):
# Ngpu * 2
v = grad_and_vars[0][1]
grads = [g for (g, _) in grad_and_vars]
grads_ret = [] # N(rev) grads
# reverse so the device placement makes the last part of model more balance?
all_grads_rev = [x[::-1] for x in all_grads] # K x N(rev)

if colocation:
with tf.device(v.device): # colocate summed grad with var
grad = aggregate(grads)
elif devices is None:
for idx, grads in enumerate(zip(*all_grads_rev)):
# grads: K tensors
if devices is None:
grad = aggregate(grads)
else:
# dev = device_setter.place_with_balance(v.op)
dev = devices[idx % len(devices)]
with tf.device(dev):
grad = aggregate(grads)
ret.append((grad, v))
return ret


average_grads = aggregate_grads
grads_ret.append(grad)
grads_ret = grads_ret[::-1]
return grads_ret


# https://github.com/tensorflow/benchmarks/blob/48cbef14a592e02a14beee8e9aef3ad22cadaed1/scripts/tf_cnn_benchmarks/variable_mgr_util.py#L140-L166
Expand Down Expand Up @@ -319,6 +373,8 @@ def __call__(self, getter, *args, **kwargs):
return var


# TODO pack at variable boundary, so that the concat does not have to wait for all
# grads to be ready
class GradientPacker(object):
"""
Concat gradients together to optimize transfer.
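
A hedged usage sketch of the new allreduce_grads_naive helper on toy tensors (TF1 graph mode with a session is assumed; this is not part of the library's own tests). The 'collective' path of allreduce_grads follows the same K x N contract, but additionally relies on the per-call group/instance keys produced by the module-level counter above.

# Toy usage of allreduce_grads_naive: 2 towers, 2 gradients each (TF1 graph mode assumed).
import tensorflow as tf
from tensorpack.graph_builder.utils import allreduce_grads_naive

all_grads = [
    [tf.constant([1.0, 2.0]), tf.constant(3.0)],   # tower 0
    [tf.constant([5.0, 6.0]), tf.constant(7.0)],   # tower 1
]
# Aggregation ops are placed round-robin on the given devices (here only the CPU).
reduced = allreduce_grads_naive(all_grads, devices=['/cpu:0'], average=True)
with tf.Session() as sess:
    print(sess.run(reduced))   # -> [array([3., 4.], dtype=float32), 5.0]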
3 changes: 3 additions & 0 deletions tensorpack/train/base.py
@@ -290,6 +290,9 @@ def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
except KeyboardInterrupt:
logger.info("Detected Ctrl-C and exiting main loop.")
raise
except Exception:
logger.error("Training failed at global_step=", self.loop.global_step)
raise
finally:
self._callbacks.after_train()
self.hooked_sess.close()
2 changes: 1 addition & 1 deletion tensorpack/train/model_desc.py
@@ -117,7 +117,7 @@ def get_optimizer(self):
"""
ret = self.optimizer()
assert isinstance(ret, tfv1.train.Optimizer), \
"ModelDesc.optimizer() must return a tf.train.Optimizer! Got {} instead.".format(str(ret))
"ModelDesc.optimizer() must return an instance of tf.train.Optimizer! Got {} instead.".format(str(ret))
return ret

def optimizer(self):
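
The tightened message above clarifies that ModelDesc.optimizer() must return an optimizer instance, not a class or a tensor. A minimal hypothetical override (only the optimizer part of a ModelDesc is shown):

# Hypothetical ModelDesc subclass: optimizer() returns a tf.train.Optimizer *instance*.
from tensorpack import ModelDesc
from tensorpack.compat import tfv1

class MyModel(ModelDesc):
    def optimizer(self):
        lr = tfv1.get_variable('learning_rate', initializer=1e-3, trainable=False)
        return tfv1.train.MomentumOptimizer(lr, momentum=0.9)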
