/
loss_scale_optimizer.py
254 lines (214 loc) · 10.5 KB
/
loss_scale_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@deprecation.deprecated_endpoints(
    'train.experimental.MixedPrecisionLossScaleOptimizer')
@tf_export(v1=['mixed_precision.MixedPrecisionLossScaleOptimizer',
               'train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, such as
  through `minimize()`.
  """

  def __init__(self, opt, loss_scale):
    """Wraps `opt` so that its gradients are computed with loss scaling.

    The wrapper reuses the wrapped optimizer's locking setting and name.

    Args:
      opt: The `Optimizer` instance to wrap.
      loss_scale: The loss scale specification; it is resolved to a loss scale
        object through `loss_scale_module.get()`.

    Raises:
      ValueError: If `opt` is not an `Optimizer` instance, or if `loss_scale`
        resolves to None.
    """
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    # Register both wrapped objects as trackable dependencies of this optimizer.
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
    """
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    """Multiplies `loss` by the current loss scale.

    Args:
      loss: A loss value, or a zero-argument callable returning one.

    Returns:
      The scaled loss. If `loss` is callable, a new callable is returned that
      scales the value produced by `loss` when it is invoked.
    """
    loss_scale = self._loss_scale()
    if callable(loss):

      def new_loss():
        loss_val = loss()
        # Cast so the multiplication happens in the loss's own dtype.
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)

      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    """Divides every non-None gradient in `grads` by the current loss scale.

    Args:
      grads: A list of gradients; entries may be None.

    Returns:
      A list of the same length where None entries are preserved and every
      other entry is multiplied by the reciprocal of the loss scale.
    """
    loss_scale = self._loss_scale()
    # Compute the reciprocal once and reuse it for every gradient.
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    """Multiplies a dense or `IndexedSlices` gradient by a scalar factor.

    The factor is cast to the gradient's dtype before multiplying, mirroring
    the cast `_scale_loss` applies to the loss scale. The cast is a no-op when
    the dtypes already match.

    Args:
      grad: The gradient to scale; either a dense value or an
        `ops.IndexedSlices`.
      loss_scale_reciprocal: The factor to multiply the gradient by.

    Returns:
      The scaled gradient, of the same kind (dense or `IndexedSlices`) as
      `grad`.
    """
    if isinstance(grad, ops.IndexedSlices):
      # Only the values are scaled; indices and dense_shape pass through.
      values = grad.values
      grad_vals = values * math_ops.cast(loss_scale_reciprocal, values.dtype)
      return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
    return grad * math_ops.cast(loss_scale_reciprocal, grad.dtype)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. When the loss scale is a
    `DynamicLossScale`, it returns an `Operation` that applies the gradients
    only if the loss scale's `update()` reports that they should be applied;
    otherwise no update is performed (nor is `global_step` incremented). For
    any other loss scale the call is forwarded directly to the wrapped
    optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross replica context.

    When users are in a cross replica strategy, they must call this rather than
    `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    # The loss scale decides, from the gradients, whether this step's update
    # should be applied, and returns an op that updates its own state.
    loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  # The four per-variable update hooks below are intentionally unusable: this
  # class forwards `apply_gradients` to the wrapped optimizer instead.

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the Optimizer.

    The result is the wrapped optimizer's variables followed by the weights
    held by the loss scale.
    """
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access