-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtpu_embedding_v2_utils.py
1295 lines (1117 loc) · 53 KB
/
tpu_embedding_v2_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Companion classes for mid level API for TPU Embeddings in TF2."""
import abc
import math
import typing
from typing import Any, Dict, Callable, Iterable, List, Optional, Text, Tuple, TypeVar, Union
from absl import logging
from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.framework import device_spec
from tensorflow.python.framework import ops
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.types import core
from tensorflow.python.util.tf_export import tf_export
# The variable backing an embedding table: a regular variable or a sharded one.
TableVariable = TypeVar(
    "TableVariable",
    sharded_variable.ShardedVariable,
    tf_variables.Variable,
)

# Signature of a user-supplied slot-variable factory. It is called with the
# table variable, the list of slot names to create and a parallel list of
# initializers, and returns the created slot variables keyed by slot name.
SlotVarCreationFnType = Callable[
    [TableVariable, List[Text], List[init_ops_v2.Initializer]],
    Dict[Text, TableVariable],
]

# Gradient clipping bounds: either an explicit (min, max) pair or a single
# scalar that `_Optimizer` expands to the symmetric pair (-value, value).
ClipValueType = Union[Tuple[float, float], float]
class _Optimizer(metaclass=abc.ABCMeta):
  """Base class for all optimizers, with common parameters.

  A subclass describes one TPU embedding optimization algorithm:
    * which slot variables it needs (`_slot_names` / `_slot_initializers`),
    * how it fills in the `OptimizationParameters` proto
      (`_set_optimization_parameters`),
    * which TPU ops load and retrieve its parameters (`_load` / `_retrieve`).
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]],
      use_gradient_accumulation: bool,
      clip_weight_min: Optional[float],
      clip_weight_max: Optional[float],
      weight_decay_factor: Optional[float],
      multiply_weight_decay_factor_by_learning_rate: bool,
      clipvalue: Optional[ClipValueType] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      low_dimensional_packing_status: bool = False,
  ):
    """Stores and validates the parameters shared by all optimizers.

    Args:
      learning_rate: A float, or a callable taking no arguments for a dynamic
        learning rate.
      use_gradient_accumulation: Whether gradient accumulation is enabled.
        Must be True when `clipvalue` is set.
      clip_weight_min: Lower weight clipping limit; None means no limit.
      clip_weight_max: Upper weight clipping limit; None means no limit.
      weight_decay_factor: Amount of weight decay; None (or 0) means no decay.
      multiply_weight_decay_factor_by_learning_rate: If true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clipvalue: Gradient clipping. Either a (min, max) tuple, where either
        entry may be None for "no clipping in that direction", or a single
        scalar `v`, which is expanded to (-v, v). None disables clipping.
      slot_variable_creation_fn: Optional callable overriding slot variable
        creation; see `_create_slots`.
      low_dimensional_packing_status: Copied verbatim into the
        `low_dimensional_packing_status` proto field.

    Raises:
      ValueError: if `clipvalue` is set while `use_gradient_accumulation` is
        False, or if `slot_variable_creation_fn` is neither None nor callable.
    """
    # NOTE: the order in which attributes are assigned below is significant:
    # `__eq__` and `__hash__` operate on the ordered `__dict__` items.
    self.learning_rate = learning_rate
    self.use_gradient_accumulation = use_gradient_accumulation
    self.clip_weight_min = clip_weight_min
    self.clip_weight_max = clip_weight_max
    if not use_gradient_accumulation and clipvalue is not None:
      raise ValueError(
          f"When `use_gradient_accumulation` is False, gradient clipping "
          f"cannot be used and `clipvalue` should be left as None. "
          f"Received value {clipvalue} for argument `clipvalue`.")
    if clipvalue is None:
      clipvalue = (None, None)
    elif not isinstance(clipvalue, tuple):
      # A single scalar means symmetric clipping around zero.
      clipvalue = (-1. * clipvalue, clipvalue)
    self.clip_gradient_min, self.clip_gradient_max = clipvalue
    self.weight_decay_factor = weight_decay_factor
    self.multiply_weight_decay_factor_by_learning_rate = (
        multiply_weight_decay_factor_by_learning_rate)
    if (slot_variable_creation_fn is not None and
        not callable(slot_variable_creation_fn)):
      raise ValueError(
          f"Argument `slot_variable_creation_fn` must be either None or a "
          f"callable. Received: {slot_variable_creation_fn}")
    self.slot_variable_creation_fn = slot_variable_creation_fn
    self.low_dimensional_packing_status = low_dimensional_packing_status

  @abc.abstractmethod
  def _slot_names(self) -> List[Text]:
    """Returns the name of all the slot variables.

    This does not include the 'parameters' variable and these names must match
    the names of the slots variables as used in the corresponding
    `tpu_ops.load_tpu_embedding_*` ops.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    """Returns initializers for slot variables.

    This returns a parallel list to self._slot_names().
    """
    raise NotImplementedError

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    """Sets the optimizer fields in the OptimizationParameters.

    Only the fields shared by all optimizers are written here; subclasses call
    this first and then fill in their algorithm-specific sub-message.

    Args:
      parameters: The proto to update in place.
    """
    if self.use_gradient_accumulation:
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.ENABLED)
    else:
      parameters.gradient_accumulation_status = (
          optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
    # Each clipping limit is only written when set, so an unset bound stays
    # absent from the proto.
    if self.clip_weight_min is not None:
      parameters.clipping_limits.lower.value = self.clip_weight_min
    if self.clip_weight_max is not None:
      parameters.clipping_limits.upper.value = self.clip_weight_max
    if self.clip_gradient_min is not None:
      parameters.gradient_clipping_limits.lower.value = self.clip_gradient_min
    if self.clip_gradient_max is not None:
      parameters.gradient_clipping_limits.upper.value = self.clip_gradient_max
    if self.weight_decay_factor:
      parameters.weight_decay_factor = self.weight_decay_factor
      if self.multiply_weight_decay_factor_by_learning_rate:
        parameters.multiply_weight_decay_factor_by_learning_rate = True
    parameters.low_dimensional_packing_status = (
        self.low_dimensional_packing_status
    )

  @abc.abstractmethod
  def _load(self) -> Callable[..., ops.Operation]:
    """Returns the load function for the optimizer."""
    raise NotImplementedError

  @abc.abstractmethod
  def _retrieve(self) -> Callable[..., core.Tensor]:
    """Returns the retrieve function for the optimizer."""
    raise NotImplementedError

  def _create_slots(
      self, table: "TableConfig",
      variable_creator: Callable[[Text, init_ops_v2.Initializer],
                                 tf_variables.Variable]
  ) -> Dict[Text, tf_variables.Variable]:
    """Creates slot variables for table.

    If a `slot_variable_creation_fn` was supplied, it is called with
    `(table, slot_names, slot_initializers)` and its result is returned as is.
    Otherwise `variable_creator` is called once per slot.

    Args:
      table: The table variable to create slots for.
      variable_creator: A function which creates variables. Takes parameters
        'name', 'initializer'.

    Returns:
      A dict of variables, keyed by self._slot_names().
    """
    if self.slot_variable_creation_fn is not None:
      return self.slot_variable_creation_fn(table, self._slot_names(),
                                            self._slot_initializers())
    else:
      slots = {}
      for slot, initializer in zip(self._slot_names(),
                                   self._slot_initializers()):
        slots[slot] = variable_creator(slot, initializer)
      return slots

  def __eq__(self, other: Any) -> Union[Any, bool]:
    """Equal when `other` is an instance of this class with identical state.

    State is the ordered list of (attribute name, value) pairs in `__dict__`.
    """
    if not isinstance(other, self.__class__):
      return False
    # Compare the complete ordered attribute lists. A pairwise zip() would stop
    # at the shorter list and ignore extra attributes on either side. This is
    # also exactly the representation `__hash__` uses, so the two agree.
    return tuple(self.__dict__.items()) == tuple(other.__dict__.items())

  def __hash__(self) -> int:
    # Must stay consistent with __eq__: hash the same ordered item tuple.
    return hash(tuple(self.__dict__.items()))
@tf_export("tpu.experimental.embedding.SGD")
class SGD(_Optimizer):
  """Optimization parameters for stochastic gradient descent for TPU embeddings.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters for global embedding optimizer defined above:

  ```
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.SGD(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.SGD(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.01,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      clipvalue: Optional[ClipValueType] = None,
      low_dimensional_packing_status: bool = False,
  ):
    """Optimization parameters for stochastic gradient descent.

    Args:
      learning_rate: The learning rate. It should be a floating point value or a
        callable taking no arguments for a dynamic learning rate.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed. Weights are decayed by multiplying the weight
        by this factor each step.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction. Gradient
        clipping requires `use_gradient_accumulation=True` (the default).
        Passing `clipvalue` with gradient accumulation disabled raises a
        `ValueError`. Gradient accumulation may decrease performance; see
        'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for more
        information on gradient accumulation and its impact on tpu embeddings.
      low_dimensional_packing_status: Status of the low-dimensional embedding
        packing optimization controls whether to optimize the packing of
        1-dimensional, 2-dimensional, and 4-dimensional embedding tables in
        memory.

    Raises:
      ValueError: if `clipvalue` is set while `use_gradient_accumulation` is
        False.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clipvalue=clipvalue,
        # SGD has no slot variables, so there is nothing to customize.
        slot_variable_creation_fn=None,
        low_dimensional_packing_status=low_dimensional_packing_status,
    )

  def _slot_names(self) -> List[Text]:
    # Plain SGD keeps no optimizer state beyond the table itself.
    return []

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    return []

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super()._set_optimization_parameters(parameters)
    # SGD has no algorithm-specific fields; mark the sub-message as present so
    # the proto selects this optimizer.
    parameters.stochastic_gradient_descent.SetInParent()

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_stochastic_gradient_descent_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_stochastic_gradient_descent_parameters
@tf_export("tpu.experimental.embedding.Adagrad")
class Adagrad(_Optimizer):
  """Optimization parameters for Adagrad with TPU embeddings.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
  ```

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters for global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.Adagrad(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.Adagrad(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      initial_accumulator_value: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None,
      low_dimensional_packing_status: bool = False,
  ):
    """Optimization parameters for Adagrad.

    Args:
      learning_rate: The learning rate. It should be a floating point value or a
        callable taking no arguments for a dynamic learning rate.
      initial_accumulator_value: initial accumulator for Adagrad. Must be
        strictly positive.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names as
        keys and the created variables as values with types matching the table
        variable. When set to None (the default), uses the built-in variable
        creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction. Requires
        `use_gradient_accumulation=True`.
      low_dimensional_packing_status: Status of the low-dimensional embedding
        packing optimization controls whether to optimize the packing of
        1-dimensional, 2-dimensional, and 4-dimensional embedding tables in
        memory.

    Raises:
      ValueError: if `initial_accumulator_value` is not positive, if
        `clipvalue` is set while `use_gradient_accumulation` is False, or if
        `slot_variable_creation_fn` is neither None nor callable.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clipvalue=clipvalue,
        slot_variable_creation_fn=slot_variable_creation_fn,
        low_dimensional_packing_status=low_dimensional_packing_status,
    )
    if initial_accumulator_value <= 0:
      raise ValueError(
          f"Argument `initial_accumulator_value` must be a positive float. "
          f"Received: {initial_accumulator_value}")
    self.initial_accumulator_value = initial_accumulator_value

  def _slot_names(self) -> List[Text]:
    return ["accumulators"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    # Parallel to _slot_names(): the accumulator starts at the configured value.
    return [init_ops_v2.Constant(self.initial_accumulator_value)]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super()._set_optimization_parameters(parameters)
    # Adagrad has no algorithm-specific fields; mark the sub-message present.
    parameters.adagrad.SetInParent()

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_adagrad_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_adagrad_parameters
@tf_export("tpu.experimental.embedding.AdagradMomentum")
class AdagradMomentum(_Optimizer):
  """Optimization parameters for Adagrad + Momentum with TPU embeddings.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.AdagradMomentum(0.1))
  ```

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters for global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.AdagradMomentum(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.AdagradMomentum(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      momentum: float = 0.0,
      use_nesterov: bool = False,
      exponent: float = 2,
      beta2: float = 1,
      epsilon: float = 1e-10,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None,
      low_dimensional_packing_status: bool = False,
  ):
    """Optimization parameters for Adagrad + Momentum.

    Args:
      learning_rate: The learning rate. It should be a floating point value or a
        callable taking no arguments for a dynamic learning rate.
      momentum: Moving average parameter for the momentum accumulator.
      use_nesterov: Whether to use the Nesterov variant of momentum. See
        Sutskever et al., 2013.
      exponent: Exponent for the Adagrad accumulator. Must be positive.
      beta2: Moving average parameter for the Adagrad accumulator.
      epsilon: initial accumulator for Adagrad accumulator. Must be positive.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names as
        keys and the created variables as values with types matching the table
        variable. When set to None (the default), uses the built-in variable
        creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction. Requires
        `use_gradient_accumulation=True`.
      low_dimensional_packing_status: Status of the low-dimensional embedding
        packing optimization controls whether to optimize the packing of
        1-dimensional, 2-dimensional, and 4-dimensional embedding tables in
        memory.

    Raises:
      ValueError: if `epsilon` or `exponent` is not positive, if `clipvalue`
        is set while `use_gradient_accumulation` is False, or if
        `slot_variable_creation_fn` is neither None nor callable.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clipvalue=clipvalue,
        slot_variable_creation_fn=slot_variable_creation_fn,
        low_dimensional_packing_status=low_dimensional_packing_status,
    )
    if epsilon <= 0:
      raise ValueError("Adagrad momentum: epsilon must be positive")
    if exponent <= 0:
      raise ValueError("Adagrad momentum: Precondition exponent must >0")
    self.momentum = momentum
    self.use_nesterov = use_nesterov
    self.exponent = exponent
    self.beta2 = beta2
    self.epsilon = epsilon

  def _slot_names(self) -> List[Text]:
    return ["accumulators", "momenta"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    # Parallel to _slot_names(); both slots use a default `Constant()`.
    return [init_ops_v2.Constant(), init_ops_v2.Constant()]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super()._set_optimization_parameters(parameters)
    adagrad_momentum = parameters.adagrad_momentum
    adagrad_momentum.SetInParent()
    adagrad_momentum.momentum = self.momentum
    adagrad_momentum.use_nesterov = self.use_nesterov
    adagrad_momentum.exponent = self.exponent
    adagrad_momentum.beta2 = self.beta2
    adagrad_momentum.epsilon = self.epsilon

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_adagrad_momentum_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_adagrad_momentum_parameters
@tf_export("tpu.experimental.embedding.FTRL")
class FTRL(_Optimizer):
  """Optimization parameters for FTRL with TPU embeddings.

  See Algorithm 1 of this
  [paper](https://research.google.com/pubs/archive/41159.pdf).

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.FTRL(0.1))
  ```

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters for global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.FTRL(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.FTRL(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      learning_rate_power: float = -0.5,
      l1_regularization_strength: float = 0.0,
      l2_regularization_strength: float = 0.0,
      beta: float = 0.0,
      initial_accumulator_value: float = 0.1,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None,
      multiply_linear_by_learning_rate: bool = False,
      allow_zero_accumulator: bool = False,
      low_dimensional_packing_status: bool = False,
  ):
    """Optimization parameters for FTRL.

    Args:
      learning_rate: The learning rate. It should be a floating point value or a
        callable taking no arguments for a dynamic learning rate.
      learning_rate_power: A float value, must be less or equal to zero.
        Controls how the learning rate decreases during training. Use zero for a
        fixed learning rate.
      l1_regularization_strength: A float value, must be greater than or equal
        to zero.
      l2_regularization_strength: A float value, must be greater than or equal
        to zero.
      beta: A float value, representing the beta value from the paper.
      initial_accumulator_value: The starting value for accumulators. Must be
        positive. Zero is also accepted when `allow_zero_accumulator` is True.
      use_gradient_accumulation: setting this to `False` makes embedding
        gradients calculation less accurate but faster.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names as
        keys and the created variables as values with types matching the table
        variable. When set to None (the default), uses the built-in variable
        creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping in that direction. Requires
        `use_gradient_accumulation=True`.
      multiply_linear_by_learning_rate: If set to True, a modified formula is
        used for FTRL that treats the "linear" accumulator as being
        pre-multiplied by the learning rate (i.e., the accumulator named
        "linear" actually stores "linear * learning_rate"). Other than
        checkpoint compatibility, this is mathematically equivalent for a static
        learning rate; for a dynamic learning rate, it is nearly the same as
        long as the learning rate does not change quickly. The benefit of this
        is that the modified formula handles zero and near-zero learning rates
        without producing NaNs, improving flexibility for learning rate ramp-up.
      allow_zero_accumulator: If set to True, changes some internal formulas to
        allow zero and near-zero accumulator values at the cost of some
        performance; this only needs to be set if you are using an initial
        accumulator value of zero, which is uncommon.
      low_dimensional_packing_status: Status of the low-dimensional embedding
        packing optimization controls whether to optimize the packing of
        1-dimensional, 2-dimensional, and 4-dimensional embedding tables in
        memory.

    Raises:
      ValueError: if `initial_accumulator_value` is negative, or is zero while
        `allow_zero_accumulator` is False; if `clipvalue` is set while
        `use_gradient_accumulation` is False; or if
        `slot_variable_creation_fn` is neither None nor callable.
    """
    super().__init__(
        learning_rate=learning_rate,
        use_gradient_accumulation=use_gradient_accumulation,
        clip_weight_min=clip_weight_min,
        clip_weight_max=clip_weight_max,
        weight_decay_factor=weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate=(
            multiply_weight_decay_factor_by_learning_rate),
        clipvalue=clipvalue,
        slot_variable_creation_fn=slot_variable_creation_fn,
        low_dimensional_packing_status=low_dimensional_packing_status,
    )
    # A zero starting accumulator is only meaningful together with
    # `allow_zero_accumulator`, which switches to formulas that tolerate it.
    # Any negative value, and zero without that flag, is rejected.
    if initial_accumulator_value < 0 or (initial_accumulator_value == 0 and
                                         not allow_zero_accumulator):
      raise ValueError(
          f"Argument `initial_accumulator_value` must be a positive float "
          f"(or zero when `allow_zero_accumulator` is True). "
          f"Received: {initial_accumulator_value}")
    self.initial_accumulator_value = initial_accumulator_value
    self.learning_rate_power = learning_rate_power
    self.l1_regularization_strength = l1_regularization_strength
    self.l2_regularization_strength = l2_regularization_strength
    self.beta = beta
    self.multiply_linear_by_learning_rate = multiply_linear_by_learning_rate
    self.allow_zero_accumulator = allow_zero_accumulator

  def _slot_names(self) -> List[Text]:
    return ["accumulators", "linears"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    # Parallel to _slot_names(): accumulators start at the configured value,
    # linears use a default `Constant()`.
    return [
        init_ops_v2.Constant(self.initial_accumulator_value),
        init_ops_v2.Constant()
    ]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    super()._set_optimization_parameters(parameters)
    ftrl = parameters.ftrl
    # Mark the sub-message present explicitly, like the other optimizers.
    ftrl.SetInParent()
    ftrl.l1 = self.l1_regularization_strength
    ftrl.l2 = self.l2_regularization_strength
    ftrl.lr_power = self.learning_rate_power
    ftrl.beta = self.beta
    ftrl.multiply_linear_by_lr = self.multiply_linear_by_learning_rate
    ftrl.allow_zero_accumulator = self.allow_zero_accumulator

  def _load(self) -> Callable[..., ops.Operation]:
    return tpu_ops.load_tpu_embedding_ftrl_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    return tpu_ops.retrieve_tpu_embedding_ftrl_parameters
@tf_export("tpu.experimental.embedding.Adam")
class Adam(_Optimizer):
  """Optimization parameters for Adam with TPU embeddings.

  NOTE: By default this optimizer is lazy, i.e. it will not apply the gradient
  update of zero to rows that were not looked up. You can change this behavior
  by setting `lazy_adam` to `False`.

  Pass this to `tf.tpu.experimental.embedding.TPUEmbedding` via the `optimizer`
  argument to set the global optimizer and its parameters:

  ```python
  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      ...
      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
  ```

  This can also be used in a `tf.tpu.experimental.embedding.TableConfig` as the
  optimizer parameter to set a table specific optimizer. This will override the
  optimizer and parameters for global embedding optimizer defined above:

  ```python
  table_one = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...,
      optimizer=tf.tpu.experimental.embedding.Adam(0.2))
  table_two = tf.tpu.experimental.embedding.TableConfig(
      vocabulary_size=...,
      dim=...)

  feature_config = (
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_one),
      tf.tpu.experimental.embedding.FeatureConfig(
          table=table_two))

  embedding = tf.tpu.experimental.embedding.TPUEmbedding(
      feature_config=feature_config,
      batch_size=...,
      optimizer=tf.tpu.experimental.embedding.Adam(0.1))
  ```

  In the above example, the first feature will be looked up in a table that has
  a learning rate of 0.2 while the second feature will be looked up in a table
  that has a learning rate of 0.1.

  See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
  complete description of these parameters and their impacts on the optimizer
  algorithm.
  """

  def __init__(
      self,
      learning_rate: Union[float, Callable[[], float]] = 0.001,
      beta_1: float = 0.9,
      beta_2: float = 0.999,
      epsilon: float = 1e-07,
      lazy_adam: bool = True,
      sum_inside_sqrt: bool = True,
      use_gradient_accumulation: bool = True,
      clip_weight_min: Optional[float] = None,
      clip_weight_max: Optional[float] = None,
      weight_decay_factor: Optional[float] = None,
      multiply_weight_decay_factor_by_learning_rate: Optional[bool] = None,
      slot_variable_creation_fn: Optional[SlotVarCreationFnType] = None,
      clipvalue: Optional[ClipValueType] = None,
      low_dimensional_packing_status: bool = False,
  ):
    """Optimization parameters for Adam.

    See 'tensorflow/core/protobuf/tpu/optimization_parameters.proto' for a
    complete description of these parameters and their impacts on the optimizer
    algorithm.

    Args:
      learning_rate: The learning rate. It should be a floating point value or a
        callable taking no arguments for a dynamic learning rate.
      beta_1: A float value. The exponential decay rate for the 1st moment
        estimates. Must be in `[0, 1)`.
      beta_2: A float value. The exponential decay rate for the 2nd moment
        estimates. Must be in `[0, 1)`.
      epsilon: A small constant for numerical stability. Must be positive.
      lazy_adam: Use lazy Adam instead of Adam. Lazy Adam trains faster.
      sum_inside_sqrt: When this is true, the Adam update formula is changed
        from `m / (sqrt(v) + epsilon)` to `m / sqrt(v + epsilon**2)`. This
        option improves the performance of TPU training and is not expected to
        harm model quality.
      use_gradient_accumulation: Setting this to `False` makes embedding
        gradients calculation less accurate but faster. Must be `True` when
        `lazy_adam` is `False`.
      clip_weight_min: the minimum value to clip by; None means -infinity.
      clip_weight_max: the maximum value to clip by; None means +infinity.
      weight_decay_factor: amount of weight decay to apply; None means that the
        weights are not decayed.
      multiply_weight_decay_factor_by_learning_rate: if true,
        `weight_decay_factor` is multiplied by the current learning rate.
      slot_variable_creation_fn: If you wish to directly control the creation
        of the slot variables, set this to a callable taking three parameters:
        a table variable, a list of slot names to create for it, and a list of
        initializers. This function should return a dict with the slot names as
        keys and the created variables as values with types matching the table
        variable. When set to None (the default), uses the built-in variable
        creation.
      clipvalue: Controls clipping of the gradient. Set to either a single
        positive scalar value to get clipping or a tuple of scalar values (min,
        max) to set a separate maximum or minimum. If one of the two entries is
        None, then there will be no clipping that direction.
      low_dimensional_packing_status: Status of the low-dimensional embedding
        packing optimization controls whether to optimize the packing of
        1-dimensional, 2-dimensional, and 4-dimensional embedding tables in
        memory.

    Raises:
      ValueError: If `beta_1` or `beta_2` is outside `[0, 1)`, if `epsilon` is
        not positive, or if both `lazy_adam` and `use_gradient_accumulation`
        are `False`.
    """
    # Arguments are forwarded positionally, in the same order the FTRL
    # constructor uses (note `clipvalue` precedes `slot_variable_creation_fn`).
    super().__init__(
        learning_rate,
        use_gradient_accumulation,
        clip_weight_min,
        clip_weight_max,
        weight_decay_factor,
        multiply_weight_decay_factor_by_learning_rate,
        clipvalue,
        slot_variable_creation_fn,
        low_dimensional_packing_status,
    )
    if beta_1 < 0. or beta_1 >= 1.:
      raise ValueError(
          f"Argument `beta_1` must be >= 0 and < 1. Received: {beta_1}.")
    if beta_2 < 0. or beta_2 >= 1.:
      raise ValueError(
          f"Argument `beta_2` must be >= 0 and < 1. Received: {beta_2}.")
    if epsilon <= 0.:
      raise ValueError("epsilon must be positive; got {}.".format(epsilon))
    # Non-lazy Adam is only accepted together with gradient accumulation.
    if not use_gradient_accumulation and not lazy_adam:
      raise ValueError(
          "When disabling lazy Adam (`lazy_adam=False`), "
          "gradient accumulation must be used. "
          "Set `use_gradient_accumulation` to True.")
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.lazy_adam = lazy_adam
    self.sum_inside_sqrt = sum_inside_sqrt

  def _slot_names(self) -> List[Text]:
    """Returns the Adam slot variable names: momenta, then velocities."""
    return ["momenta", "velocities"]

  def _slot_initializers(self) -> List[init_ops_v2.Initializer]:
    """Returns one `init_ops_v2.Constant()` initializer per slot name."""
    return [init_ops_v2.Constant(), init_ops_v2.Constant()]

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    """Fills the base-class fields, then the `parameters.adam` sub-message."""
    super()._set_optimization_parameters(parameters)
    parameters.adam.beta1 = self.beta_1
    parameters.adam.beta2 = self.beta_2
    parameters.adam.epsilon = self.epsilon
    # The proto flag is phrased negatively relative to `lazy_adam`.
    parameters.adam.use_non_lazy_adam = not self.lazy_adam
    parameters.adam.use_sum_inside_sqrt = self.sum_inside_sqrt

  def _load(self) -> Callable[..., ops.Operation]:
    """Returns the Adam load op, `load_tpu_embedding_adam_parameters`."""
    return tpu_ops.load_tpu_embedding_adam_parameters

  def _retrieve(self) -> Callable[..., core.Tensor]:
    """Returns the Adam retrieve op, `retrieve_tpu_embedding_adam_parameters`."""
    return tpu_ops.retrieve_tpu_embedding_adam_parameters
@tf_export("tpu.experimental.embedding.QuantizationConfig")
class QuantizationConfig:
  """Settings for simulated quantization of the tpu embedding table.

  When simulated quantization is enabled, the results of the embedding lookup
  are clipped and quantized according to the settings here before the combiner
  is applied.

  For example, to quantize `input` the following is done:

  ```python
  if input < lower:
    input = lower
  if input > upper:
    input = upper
  quantum = (upper - lower) / (num_buckets - 1)
  input = math.floor((input - lower) / quantum + 0.5) * quantum + lower
  ```

  See tensorflow/core/protobuf/tpu/optimization_parameters.proto for more
  details.

  NOTE: This does not change the storage type of the embedding table, that will
  continue to be float32 as will the saved variable in the checkpoint. You will
  have to quantize the variable manually (typically with the same algorithm and
  settings as above).
  """

  def __init__(self, num_buckets: int, lower: float, upper: float):
    """Simulated quantization configuration.

    Args:
      num_buckets: The number of quantization buckets, must be at least 2.
      lower: The lower bound for the quantization range.
      upper: The upper bound for the quantization range.

    Returns:
      `QuantizationConfig`.

    Raises:
      ValueError: if `num_buckets` is less than 2.
    """
    if num_buckets < 2:
      raise ValueError(f"num_buckets is {num_buckets}, must be at least 2 for "
                       f"simulated quantization.")
    self.num_buckets, self.lower, self.upper = num_buckets, lower, upper

  def _set_optimization_parameters(
      self, parameters: optimization_parameters_pb2.OptimizationParameters):
    """Enables simulated quantization in `parameters` with these settings."""
    sim_quant = parameters.simulated_quantization
    sim_quant.enabled = True
    sim_quant.num_buckets = self.num_buckets
    limits = sim_quant.clipping_limits
    limits.lower.value = self.lower
    limits.upper.value = self.upper

  def __repr__(self):
    return (f"QuantizationConfig(num_buckets={self.num_buckets!r}, "
            f"lower={self.lower!r}, upper={self.upper!r})")
@tf_export("tpu.experimental.embedding.TableConfig")
class TableConfig:
"""Configuration data for one embedding table.
This class holds the configuration data for a single embedding table. It is
used as the `table` parameter of a
`tf.tpu.experimental.embedding.FeatureConfig`. Multiple
`tf.tpu.experimental.embedding.FeatureConfig` objects can use the same
`tf.tpu.experimental.embedding.TableConfig` object. In this case a shared
table will be created for those feature lookups.
```python
table_config_one = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=...,
dim=...)
table_config_two = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=...,
dim=...)
feature_config = {
'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
table=table_config_one),
'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
table=table_config_one),
'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
table=table_config_two)}
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
batch_size=...
optimizer=tf.tpu.experimental.embedding.Adam(0.1))
```
The above configuration has 2 tables, and three features. The first two