* Add the option to use bf16 for gradient all-reduces.
* Add the option to skip gradient adjustment and scaling; gradient normalization is itself expensive and adds extra all-gather overhead.

PiperOrigin-RevId: 307919356
lingvo-bot authored and Copybara-Service committed Apr 22, 2020
1 parent c2e1e24 commit 6138c97
Showing 4 changed files with 35 additions and 9 deletions.
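
For context, a minimal sketch of how a task configuration might opt into both new behaviors. The parameter paths below (p.train.scale_gradients, p.train.optimizer.use_bf16_gradients_ar) are assumptions inferred from the diff, not a verified Lingvo recipe.

# Hypothetical configuration sketch (not part of this commit): assumes a
# Lingvo task params object `p` with the usual p.train / optimizer layout.
def ConfigureFastTpuGradients(p):
  # Skip the learner's gradient adjustment/scaling step (e.g. norm-based
  # clipping), avoiding its extra all-gather overhead.
  p.train.scale_gradients = False
  # All-reduce gradients across TPU replicas in bfloat16 to halve the
  # all-reduce payload.
  p.train.optimizer.use_bf16_gradients_ar = True
  return p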
2 changes: 2 additions & 0 deletions lingvo/core/base_model.py
@@ -203,6 +203,8 @@ def Params(cls):
               'operations. This avoids some race conditions.')
     tp.Define('colocate_gradients_with_ops', True,
               'If True, try colocating gradients with the corresponding op.')
+    tp.Define('scale_gradients', True,
+              'Whether to apply gradients adjustment and scaling.')
     # LINT.ThenChange(learner.py)
     p.Define('eval', hyperparams.Params(),
              'Params to control how this task should be evaled.')
12 changes: 8 additions & 4 deletions lingvo/core/learner.py
@@ -95,6 +95,8 @@ def Params(cls):
         'None: do not skip zero gradients; '
         '"variable": skip if the entire variable gradients are almost zero; '
         '"weight": skip if the individual weight gradients are almost zero.')
+    p.Define('scale_gradients', True,
+             'Whether to apply gradients adjustment and scaling.')
     return p

   @base_layer.initializer
@@ -229,10 +231,11 @@ def AdjustGradients(self,
     if gradient_mask:
       var_grads = py_utils.MaskGradients(var_grads, gradient_mask)

-    # Apply gradient clipping.
-    scaled_vars = self.ScaleGradients(
-        var_grads, gradient_adjuster=gradient_adjuster)
-    var_grads = scaled_vars.final_var_grads
+    # Scale gradients, e.g., gradient clipping.
+    if p.scale_gradients:
+      scaled_vars = self.ScaleGradients(
+          var_grads, gradient_adjuster=gradient_adjuster)
+      var_grads = scaled_vars.final_var_grads

     # Histogram summary.
     summary_utils.CollectVarHistogram(var_grads)
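
What p.scale_gradients=False skips is, roughly, a global gradient-norm computation and clipping pass; that norm has to see every gradient tensor, which is the extra all-gather overhead called out in the commit message. A standalone illustration of that kind of step, assuming ScaleGradients performs norm-based clipping; this is not the actual implementation:

# Illustration only: a global-norm clipping pass of the kind that
# p.scale_gradients=False bypasses. Uses TF1-style ops, as in Lingvo.
import tensorflow.compat.v1 as tf

def ClipGradsByGlobalNorm(grads, clip_norm=1.0):
  # tf.clip_by_global_norm needs the norm over *all* gradients, so every
  # gradient tensor must be materialized and reduced before the optimizer
  # update can proceed.
  clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm)
  return clipped, global_norm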
@@ -352,6 +355,7 @@ def _AddScalarSummary(self, key, value):
       'clip_gradient_single_norm_to_value',
       'colocate_gradients_with_ops',
       'gate_gradients',
+      'scale_gradients',
       'grad_aggregation_method',
       'grad_norm_to_clip_to_zero',
       'grad_norm_tracker',
14 changes: 12 additions & 2 deletions lingvo/core/optimizer.py
@@ -34,6 +34,10 @@ class Base(base_layer.BaseLayer):
   def Params(cls):
     p = super(Base, cls).Params()
     p.name = cls.__name__
+    p.Define(
+        'use_bf16_gradients_ar', False,
+        'Whether to use bfloat16 dtype for gradients all-reduce. '
+        'This applies to TPU only.')
     return p

   def GetOptimizer(self, lr):
@@ -46,6 +50,7 @@ def AddSummary(self, lr, optimizer, var_grad):

   def ComputeGradients(self, loss, vmap, *args, **kwargs):
     """Allows subclasses control computation of gradients."""
+    kwargs['use_bf16_gradients_ar'] = self.params.use_bf16_gradients_ar
     return py_utils.ComputeGradients(loss, vmap, *args, **kwargs)

   def VarReuseForSlotVars(self):
@@ -68,8 +73,13 @@ def Apply(self, lr, var_grad):
     optimizer = self.GetOptimizer(lr)

     def _Apply():
-      return optimizer.apply_gradients(
-          [(g, v) for (v, g) in var_grad.Flatten()], name='meta_backprop')
+      if self.params.use_bf16_gradients_ar:
+        return optimizer.apply_gradients(
+            [(tf.cast(g, tf.float32), v) for (v, g) in var_grad.Flatten()],
+            name='meta_backprop')
+      else:
+        return optimizer.apply_gradients(
+            [(g, v) for (v, g) in var_grad.Flatten()], name='meta_backprop')

     if not py_utils.use_resource_variables():
       var_update_op = _Apply()
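
With use_bf16_gradients_ar enabled, the gradients reaching Apply are bfloat16 (they were cast before the all-reduce in py_utils.py below), so they are cast back to float32 here because apply_gradients generally expects the gradient dtype to match the variable dtype. A minimal standalone sketch of that pattern, using plain TF1-style ops rather than the Lingvo code path:

# Minimal sketch: cast a bfloat16 gradient back to float32 before applying it.
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

v = tf.get_variable('w', shape=[8], dtype=tf.float32)
g_bf16 = tf.ones_like(v, dtype=tf.bfloat16)  # stand-in for an all-reduced grad
opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# Cast back to float32 so the optimizer update runs in the variable's dtype.
update_op = opt.apply_gradients([(tf.cast(g_bf16, tf.float32), v)])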
16 changes: 13 additions & 3 deletions lingvo/core/py_utils.py
@@ -1950,7 +1950,8 @@ def _ComputeGradientsTpu(loss,
                          grad_aggregation_method,
                          colocate_gradients_with_ops,
                          gate_gradients,
-                         skip_zero_gradients=None):
+                         skip_zero_gradients=None,
+                         use_bf16_gradients_ar=False):
   """Computes gradients for local loss across whole TPU cluster.

   This implementation specializes for the case where weight params maybe used
@@ -1969,6 +1970,8 @@ def _ComputeGradientsTpu(loss,
       with the original op.
     gate_gradients: boolean, flag to be passed to tf.gradients.
     skip_zero_gradients: whether to skip zero gradients during aggregation.
+    use_bf16_gradients_ar: Whether to use bfloat16 dtype for gradients
+      all-reduce.

   Returns:
     Gradients to be passed back.
@@ -2002,6 +2005,8 @@ def _ComputeGradientsTpu(loss,
     if g is None:
       aggregated_grads.append(None)
       continue
+    if use_bf16_gradients_ar:
+      g = tf.cast(g, tf.bfloat16)
     with tf.colocate_with(g):
       if skip_zero_gradients is None:
         # loss is already scaled by 1/shards.
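
The cast to bfloat16 is applied per gradient just before aggregation, so the cross-replica all-reduce that follows moves half as many bytes. A hedged sketch of that pattern; the helper below is made up for illustration, and the real aggregation code lives further down in _ComputeGradientsTpu:

# Sketch of the surrounding pattern (assumption: gradients are aggregated
# with a cross-replica sum, which is only meaningful when running on TPU).
import tensorflow.compat.v1 as tf

def AllReduceGradient(g, use_bf16_gradients_ar=False):
  if use_bf16_gradients_ar:
    g = tf.cast(g, tf.bfloat16)      # halves the bytes moved by the all-reduce
  # Sum this replica's gradient with the matching gradients on all other
  # TPU replicas.
  return tf.tpu.cross_replica_sum(g)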
@@ -2058,7 +2063,8 @@ def ComputeGradients(
     colocate_gradients_with_ops=True,
     gate_gradients=False,
     compute_gradients_fn=None,
-    skip_zero_gradients=None):
+    skip_zero_gradients=None,
+    use_bf16_gradients_ar=False):
   """Computes gradients of variables in vmap w.r.t loss.

   Args:
@@ -2085,6 +2091,8 @@ def ComputeGradients(
        reduce_sum(abs(grads)) < 1e-8.
      * `weight`: skip if the individual weight's gradients are almost zero:
        abs(grad) < 1e-8.
+    use_bf16_gradients_ar: Whether to use bfloat16 dtype for gradients
+      all-reduce. This applies to TPU only.

   Returns:
     var_grad - a `.NestedMap` of VarGrad. You can view
@@ -2126,7 +2134,9 @@ def Needed(v):
   # tpu vs non-tpu is slightly different.
   if use_tpu():
     take_grad = functools.partial(
-        _ComputeGradientsTpu, skip_zero_gradients=skip_zero_gradients)
+        _ComputeGradientsTpu,
+        skip_zero_gradients=skip_zero_gradients,
+        use_bf16_gradients_ar=use_bf16_gradients_ar)
   else:
     take_grad = ComputeGradientsSimple

