/
loss_scale_optimizer.py
412 lines (345 loc) · 17.4 KB
/
loss_scale_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the loss scaling optimizer class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.distribute import one_device_strategy
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.keras import backend
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.mixed_precision.experimental import loss_scale as keras_loss_scale_module
from tensorflow.python.keras.optimizer_v2 import optimizer_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util.tf_export import keras_export
class _UnwrapPreventer(object):
"""Wrapper that DistributionStrategy will not unwrap.
Typically, DistributionStrategy will unwrap values when going from a cross-
replica context to a replica context via `call_for_each_replica`. This class
is a wrapper that DistributionStrategy will not unwrap, so it can be used to
prevent it from unwrapping a value.
TODO(reedwm): Find/implement a better way of preventing values from being
unwrapped by DistributionStrategy
"""
def __init__(self, value):
self.value = value
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizer(optimizer_v2.OptimizerV2):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used. By
  multiplying the loss, each intermediate gradient will have the same multiplier
  applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However, it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, either
  through `minimize()` or `get_gradients()`. The loss scale is updated via
  `LossScale.update()` whenever gradients are applied, either through
  `minimize()` or `apply_gradients()`. For example:

  >>> opt = tf.keras.optimizers.SGD(0.25)
  >>> opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt,
  ...                                                                "dynamic")
  >>> var = tf.Variable(1.)
  >>> loss_fn = lambda: var ** 2
  >>> # 'minimize' applies loss scaling to the loss and updates the loss scale.
  >>> opt.minimize(loss_fn, var_list=var)
  >>> var.numpy()
  0.5

  If a `tf.GradientTape` is used to compute gradients instead of
  `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
  and gradients must be scaled manually. This can be done by calling
  `LossScaleOptimizer.get_scaled_loss` before passing the loss to
  `tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
  computing the gradients with `tf.GradientTape`. For example:

  >>> with tf.GradientTape() as tape:
  ...   loss = loss_fn()
  ...   scaled_loss = opt.get_scaled_loss(loss)
  >>> scaled_grad = tape.gradient(scaled_loss, var)
  >>> (grad,) = opt.get_unscaled_gradients([scaled_grad])
  >>> opt.apply_gradients([(grad, var)])  # Loss scale is updated here
  >>> var.numpy()
  0.25
  """

  # NOTE(review): presumably advertises that `apply_gradients` below accepts
  # `experimental_aggregate_gradients` -- confirm against OptimizerV2 callers.
  _HAS_AGGREGATE_GRAD = True

  def __init__(self, optimizer, loss_scale):
    """Initializes this loss scale optimizer.

    Args:
      optimizer: The Optimizer instance to wrap.
      loss_scale: The loss scale to scale the loss and gradients. This can
        either be an int/float to use a fixed loss scale, the string "dynamic"
        to use dynamic loss scaling, or an instance of a LossScale. The string
        "dynamic" is equivalent to passing `DynamicLossScale()`, and passing an
        int/float is equivalent to passing a FixedLossScale with the given loss
        scale.

    Raises:
      ValueError: If `optimizer` is not an `OptimizerV2`, if it has `clipnorm`
        or `clipvalue` set, if the current tf.distribute.Strategy does not
        support loss scaling, or if `loss_scale` resolves to None.
    """
    if not isinstance(optimizer, optimizer_v2.OptimizerV2):
      raise ValueError('"optimizer" must be an instance of OptimizerV2, but '
                       'got: %s' % optimizer)
    if optimizer.clipnorm is not None:
      raise ValueError('LossScaleOptimizer does not support wrapping '
                       'optimizers with a clipnorm. Optimizer %s has clipnorm '
                       '%s' % (optimizer, optimizer.clipnorm))

    if optimizer.clipvalue is not None:
      raise ValueError('LossScaleOptimizer does not support wrapping '
                       'optimizers with a clipvalue. Optimizer %s has '
                       'clipvalue %s' % (optimizer, optimizer.clipvalue))
    self._raise_if_strategy_unsupported()

    self.clipnorm = None
    self.clipvalue = None

    self._optimizer = optimizer
    self._loss_scale = keras_loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None.')
    for weight in loss_scale_module.get_loss_scale_weights(self._loss_scale):
      # We cannot call `track_variable` in the LossScale class itself, because a
      # file outside of Keras cannot depend on a Keras file. Calling it here
      # instead is OK, because a variable only needs to be tracked if used with
      # a Keras class, and the only way to use LossScale with a Keras class is
      # through the LossScaleOptimizer.
      backend.track_variable(weight)
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

    # Needed because the superclass's __getattribute__ checks this.
    self._hyper = {}

  @property
  def loss_scale(self):
    """The `LossScale` instance associated with this optimizer."""
    return self._loss_scale

  def get_scaled_loss(self, loss):
    """Scales the loss by the loss scale.

    This method is only needed if you compute gradients manually, e.g. with
    `tf.GradientTape`. In that case, call this method to scale the loss before
    passing the loss to `tf.GradientTape`. If you use
    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
    scaling is automatically applied and this method is unneeded.

    If this method is called, `get_unscaled_gradients` should also be called.
    See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for
    an example.

    Args:
      loss: The loss, which will be multiplied by the loss scale. Can either be
        a tensor or a callable returning a tensor.

    Returns:
      `loss` multiplied by `LossScaleOptimizer.loss_scale()`. If `loss` is a
      callable, a new callable returning the scaled loss is returned instead.
    """
    # The current loss scale is read once, here; a callable `loss` captures it.
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def get_unscaled_gradients(self, grads):
    """Unscales the gradients by the loss scale.

    This method is only needed if you compute gradients manually, e.g. with
    `tf.GradientTape`. In that case, call this method to unscale the gradients
    after computing them with `tf.GradientTape`. If you use
    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
    scaling is automatically applied and this method is unneeded.

    If this method is called, `get_scaled_loss` should also be called. See
    the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an
    example.

    Args:
      grads: A list of tensors, each which will be divided by the loss scale.
        Can have None values, which are ignored. `IndexedSlices` entries are
        supported and stay sparse.

    Returns:
      A new list the same size as `grads`, where every non-None value in `grads`
      is divided by `LossScaleOptimizer.loss_scale()`. `IndexedSlices`
      gradients are returned as `IndexedSlices` with their values scaled.
    """
    loss_scale = self._loss_scale()
    loss_scale_reciprocal = 1. / loss_scale

    def _unscale(grad):
      scale = math_ops.cast(loss_scale_reciprocal, grad.dtype)
      if isinstance(grad, ops.IndexedSlices):
        # Multiplying an IndexedSlices by a tensor would first convert it to a
        # dense tensor. Scale only the values so sparse gradients (e.g. from
        # embedding lookups) keep their sparse representation.
        return ops.IndexedSlices(grad.values * scale, grad.indices,
                                 dense_shape=grad.dense_shape)
      return grad * scale

    return [_unscale(g) if g is not None else None for g in grads]

  def _compute_gradients(self, loss, var_list, grad_loss=None):
    # Scale the loss, let the wrapped optimizer differentiate it, then unscale
    # the resulting gradients before returning (gradient, variable) pairs.
    loss = self.get_scaled_loss(loss)
    grads_and_vars = self._optimizer._compute_gradients(loss, var_list,  # pylint: disable=protected-access
                                                        grad_loss)
    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self.get_unscaled_gradients(grads)
    return list(zip(unscaled_grads, variables))

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` w.r.t. `params`, with loss scaling applied.

    The loss is scaled before the wrapped optimizer's `get_gradients` is
    called, and the resulting gradients are unscaled before being returned.
    """
    loss = self.get_scaled_loss(loss)
    grads = self._optimizer.get_gradients(loss, params)
    return self.get_unscaled_gradients(grads)

  def _create_all_weights(self, var_list):
    self._optimizer._create_all_weights(var_list)  # pylint: disable=protected-access

  def apply_gradients(self,
                      grads_and_vars,
                      name=None,
                      experimental_aggregate_gradients=True):
    """Updates the loss scale and, if allowed, applies the gradients.

    This method does not unscale the gradients; pass gradients that were
    already unscaled (e.g. by `get_unscaled_gradients`).

    Args:
      grads_and_vars: Iterable of (gradient, variable) pairs.
      name: Optional name forwarded to the wrapped optimizer's
        `apply_gradients`.
      experimental_aggregate_gradients: Forwarded to the wrapped optimizer's
        `apply_gradients`.

    Returns:
      The result of the cross-replica `merge_call`, which groups the
      (conditional) gradient application with the loss scale update.

    Raises:
      ValueError: If called in a cross-replica context, or if the current
        tf.distribute.Strategy does not support loss scaling.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')
    # We check for the strategy here despite already checking in the constructor
    # as frequently the optimizer is created outside the strategy's scope.
    self._raise_if_strategy_unsupported()

    # Materialize the iterable: it is iterated more than once downstream.
    grads_and_vars = tuple(grads_and_vars)
    return distribution_strategy_context.get_replica_context().merge_call(
        self._apply_gradients_cross_replica,
        args=(grads_and_vars, name, experimental_aggregate_gradients))

  def _apply_gradients_cross_replica(self, distribution, grads_and_vars, name,
                                     experimental_aggregate_gradients):
    # Runs in a cross-replica context (invoked through `merge_call`).
    grads = [g for g, _ in grads_and_vars]
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def apply_fn():
      # We do not want DistributionStrategy to unwrap any MirroredVariables in
      # grads_and_vars, because even in a replica context, the wrapped optimizer
      # expects mirrored variables. So we wrap the variables with an
      # _UnwrapPreventer, preventing DistributionStrategy from unwrapping the
      # MirroredVariables.
      wrapped_vars = _UnwrapPreventer([v for _, v in grads_and_vars])
      return distribution.extended.call_for_each_replica(
          self._apply_gradients,
          args=(grads, wrapped_vars, name, experimental_aggregate_gradients))

    # Note: We must call this cond() in a cross-replica context.
    # DistributionStrategy does not support having a cond in a replica context
    # with a branch that calls `merge_call`, and self._optimizer.apply_gradients
    # calls `merge_call`.
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads,
                                           apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(maybe_apply_op, loss_scale_update_op)

  def _apply_gradients(self, grads, wrapped_vars, name,
                       experimental_aggregate_gradients):
    # Runs once per replica via `call_for_each_replica`; `wrapped_vars` is an
    # _UnwrapPreventer holding the original (possibly mirrored) variables.
    # TODO(reedwm): This will raise a fairly cryptic error message if
    # self._optimizer.apply_gradients does not take
    # experimental_aggregate_gradients.
    return self._optimizer.apply_gradients(
        list(zip(grads, wrapped_vars.value)), name,
        experimental_aggregate_gradients=experimental_aggregate_gradients)

  def get_config(self):
    serialized_optimizer = optimizers.serialize(self._optimizer)
    serialized_loss_scale = keras_loss_scale_module.serialize(self._loss_scale)
    return {
        'optimizer': serialized_optimizer,
        'loss_scale': serialized_loss_scale,
    }

  @classmethod
  def from_config(cls, config, custom_objects=None):
    config = config.copy()  # Make a copy, since we mutate config
    config['optimizer'] = optimizers.deserialize(
        config['optimizer'], custom_objects=custom_objects)
    config['loss_scale'] = keras_loss_scale_module.deserialize(
        config['loss_scale'], custom_objects=custom_objects)
    return cls(**config)

  def _raise_if_strategy_unsupported(self):
    if not strategy_supports_loss_scaling():
      strategy = distribution_strategy_context.get_strategy()
      raise ValueError('Loss scaling is not supported with the '
                       'tf.distribute.Strategy: %s. Try using a different '
                       'Strategy, e.g. a MirroredStrategy' %
                       strategy.__class__.__name__)

  # Delegations: We delegate most OptimizerV2 methods to the wrapped optimizer
  # below.

  @property
  def iterations(self):
    return self._optimizer.iterations

  @iterations.setter
  def iterations(self, variable):
    self._optimizer.iterations = variable

  def get_slot_names(self):
    return self._optimizer.get_slot_names()

  def variables(self):
    return self._optimizer.variables()

  @property
  def weights(self):
    return self._optimizer.weights

  def get_weights(self):
    return self._optimizer.get_weights()

  def set_weights(self, weights):
    return self._optimizer.set_weights(weights)

  def _aggregate_gradients(self, grads_and_vars):
    return self._optimizer._aggregate_gradients(grads_and_vars)  # pylint: disable=protected-access

  # For the most part, we only expose methods in the base OptimizerV2, not
  # individual subclasses like Adam. However, although "learning_rate" and "lr"
  # properties are not part of the base OptimizerV2 class, they are part of most
  # subclasses, so we expose them here for convenience.

  @property
  def learning_rate(self):
    return self._optimizer.learning_rate

  @learning_rate.setter
  def learning_rate(self, lr):
    self._optimizer.learning_rate = lr

  @property
  def lr(self):
    return self._optimizer.lr

  @lr.setter
  def lr(self, lr):
    self._optimizer.lr = lr

  def get_slot(self, var, slot_name):
    # We cannot implement get_slot for the following reason: When saving a
    # checkpoint, two optimizers cannot share slot variables. Since both the
    # LossScaleOptimizer and the wrapped optimizer (self and self._optimizer
    # respectively) are checkpointed, we cannot expose the wrapped optimizer's
    # slots in the LossScaleOptimizer. Otherwise, a checkpoint would believe
    # both optimizers share slot variables.
    raise AttributeError(
        'You cannot call get_slot on a LossScaleOptimizer. This limitation '
        'will be removed in the future.')

  def add_slot(self, var, slot_name, initializer='zeros'):
    # We disallow adding a slot for consistency with `get_slot`.
    raise AttributeError(
        'You cannot call add_slot on a LossScaleOptimizer. This limitation '
        'will be removed in the future.')

  # We do not override some OptimizerV2 methods. For each, we describe why we do
  # not delegate them to self._optimizer:
  # * get_updates: get_updates() calls get_gradients(). Since we override
  #   get_gradients(), we cannot delegate get_updates() to self._optimizer,
  #   otherwise the overridden get_gradients() method would not be called.
  #   Luckily, get_updates() does not access any OptimizerV2 fields, so
  #   inheriting the OptimizerV2 version works fine.
  # * minimize: We don't delegate for a similar reason as get_updates(): it
  #   calls both self._compute_gradients() and self.apply_gradients(), and both
  #   need to have the LossScaleOptimizer version called.

  # TODO(reedwm): Maybe merge this class's functionality into OptimizerV2.

  # TODO(reedwm): Maybe throw an error if mixed precision is used without this
  # optimizer being used.
def strategy_supports_loss_scaling():
  """Reports whether loss scaling can be used under the current Strategy.

  Returns:
    True if no tf.distribute.Strategy is active, or if the active strategy is
    an instance of one of the strategy classes enumerated below; False
    otherwise.
  """
  if distribution_strategy_context.has_strategy():
    current_strategy = distribution_strategy_context.get_strategy()
    # Only strategies that either run a single replica or keep one variable
    # replica per compute device are accepted. With any other layout, the
    # current model.fit() implementation and most custom training loops unscale
    # gradients incorrectly: unscaling happens once per compute replica, while
    # it should happen once per variable replica. The two counts agree only
    # when every compute replica owns its own variable replica.
    # TODO(reedwm): Support all strategies.
    supported_strategy_types = (
        collective_all_reduce_strategy.CollectiveAllReduceStrategy,
        collective_all_reduce_strategy.CollectiveAllReduceStrategyV1,
        one_device_strategy.OneDeviceStrategy,
        one_device_strategy.OneDeviceStrategyV1,
        mirrored_strategy.MirroredStrategy,
        mirrored_strategy.MirroredStrategyV1,
    )
    return isinstance(current_strategy, supported_strategy_types)
  # Without any active strategy there is nothing to conflict with.
  return True