-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtpu_embedding_v2.py
1676 lines (1450 loc) · 73.6 KB
/
tpu_embedding_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Mid level API for TPU Embeddings."""
import functools
from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from absl import logging
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute import distribute_utils
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import sharded_variable
from tensorflow.python.distribute import tpu_strategy
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework.tensor_shape import TensorShape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import registration
from tensorflow.python.saved_model import save_context
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base
from tensorflow.python.types import internal as internal_types
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
_HOOK_KEY = "TPUEmbedding_saveable"
_NAME_KEY = "_tpu_embedding_layer"
class TPUEmbeddingVariable(sharded_variable.ShardedVariableMixin):
  """ShardedVariable flavour used for TPU embedding tables."""

  @property
  def _in_graph_mode(self):
    """Reports the `_in_graph_mode` flag of the first shard in `variables`."""
    leading_shard = self.variables[0]
    return leading_shard._in_graph_mode  # pylint: disable=protected-access
def _add_key_attr(op, name):
  """Records `name` on `op` under the `_NAME_KEY` attribute.

  Args:
    op: The op to tag; it must expose `_set_attr`.
    name: Text (or bytes) to store; converted with `compat.as_bytes`.
  """
  encoded = attr_value_pb2.AttrValue(s=compat.as_bytes(name))
  op._set_attr(_NAME_KEY, encoded)  # pylint: disable=protected-access
@tf_export("tpu.experimental.embedding.TPUEmbedding")
class TPUEmbedding(autotrackable.AutoTrackable):
"""The TPUEmbedding mid level API.
NOTE: When instantiated under a TPUStrategy, this class can only be created
once per call to `tf.tpu.experimental.initialize_tpu_system`. If you wish to
re-initialize the embedding engine you must re-initialize the tpu as well.
Doing this will clear any variables from TPU, so ensure you have checkpointed
before you do this. If a further instances of the class are needed,
set the `initialize_tpu_embedding` argument to `False`.
This class can be used to support training large embeddings on TPU. When
creating an instance of this class, you must specify the complete set of
tables and features you expect to lookup in those tables. See the
documentation of `tf.tpu.experimental.embedding.TableConfig` and
`tf.tpu.experimental.embedding.FeatureConfig` for more details on the complete
set of options. We will cover the basic usage here.
NOTE: multiple `FeatureConfig` objects can use the same `TableConfig` object,
allowing different features to share the same table:
```python
table_config_one = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=...,
dim=...)
table_config_two = tf.tpu.experimental.embedding.TableConfig(
vocabulary_size=...,
dim=...)
feature_config = {
'feature_one': tf.tpu.experimental.embedding.FeatureConfig(
table=table_config_one),
'feature_two': tf.tpu.experimental.embedding.FeatureConfig(
table=table_config_one),
'feature_three': tf.tpu.experimental.embedding.FeatureConfig(
table=table_config_two)}
```
There are two modes under which the `TPUEmbedding` class can used. This
depends on if the class was created under a `TPUStrategy` scope or not.
Under `TPUStrategy`, we allow access to the method `enqueue`, `dequeue` and
`apply_gradients`. We will show examples below of how to use these to train
and evaluate your model. Under CPU, we only access to the `embedding_tables`
property which allow access to the embedding tables so that you can use them
to run model evaluation/prediction on CPU.
First lets look at the `TPUStrategy` mode. Initial setup looks like:
```python
strategy = tf.distribute.TPUStrategy(...)
with strategy.scope():
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
```
When creating a distributed dataset that is to be passed to the enqueue
operation a special input option must be specified:
```python
distributed_dataset = (
strategy.distribute_datasets_from_function(
dataset_fn=...,
options=tf.distribute.InputOptions(
experimental_fetch_to_device=False))
dataset_iterator = iter(distributed_dataset)
```
Different feature inputs can have different shapes. For dense and sparse
tensor, rank 2 and above is supported. For ragged tensor, although only rank 2
is supported, you can specify the output shape to be rank 2 and above. The
output shape specified in the FeatureConfig has the first priority. The input
shape passed in build method has second priority and the input shapes
auto detected from input feature has the lowest priority. The latter two will
be converted to output shapes by omitting the last dimension. If the lower
priority one has output shapes which don't match the former one. A ValueError
will be raised. Only when the former one has undefined output shapes, the
latter one can override.
NOTE: All batches passed to the layer can have different input shapes. But
these input shapes need to match with the output shapes set by either
`FeatureConfig` or build method except for ragged tensor. Only 2D
ragged tensor with output shape set to higher dimensions is allowed as
long as the total number of elements matches. All subsequent calls must have
the same input shapes. In the event that the input shapes cannot be
automatically determined by the enqueue method, you must call
the build method with the input shapes or provide output shapes in the
`FeatureConfig` to initialize the layer.
To use this API on TPU you should use a custom training loop. Below is an
example of a training and evaluation step:
```python
@tf.function
def training_step(dataset_iterator, num_steps):
def tpu_step(tpu_features):
with tf.GradientTape() as tape:
activations = embedding.dequeue()
tape.watch(activations)
model_output = model(activations)
loss = ... # some function of labels and model_output
embedding_gradients = tape.gradient(loss, activations)
embedding.apply_gradients(embedding_gradients)
# Insert your model gradient and optimizer application here
for _ in tf.range(num_steps):
embedding_features, tpu_features = next(dataset_iterator)
embedding.enqueue(embedding_features, training=True)
strategy.run(tpu_step, args=(tpu_features, ))
@tf.function
def evaluation_step(dataset_iterator, num_steps):
def tpu_step(tpu_features):
activations = embedding.dequeue()
model_output = model(activations)
# Insert your evaluation code here.
for _ in tf.range(num_steps):
embedding_features, tpu_features = next(dataset_iterator)
embedding.enqueue(embedding_features, training=False)
strategy.run(tpu_step, args=(tpu_features, ))
```
NOTE: The calls to `enqueue` have `training` set to `True` when
`embedding.apply_gradients` is used and set to `False` when
`embedding.apply_gradients` is not present in the function. If you don't
follow this pattern you may cause an error to be raised or the tpu may
deadlock.
In the above examples, we assume that the user has a dataset which returns
a tuple where the first element of the tuple matches the structure of what
was passed as the `feature_config` argument to the object initializer. Also we
utilize `tf.range` to get a `tf.while_loop` in order to increase performance.
When checkpointing your model, you should include your
`tf.tpu.experimental.embedding.TPUEmbedding` object in the checkpoint. It is a
trackable object and saving it will save the embedding tables and their
optimizer slot variables:
```python
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.save(...)
```
On CPU, only the `embedding_table` property is usable. This will allow you to
restore a checkpoint to the object and have access to the table variables:
```python
model = model_fn(...)
embedding = tf.tpu.experimental.embedding.TPUEmbedding(
feature_config=feature_config,
optimizer=tf.tpu.experimental.embedding.SGD(0.1))
checkpoint = tf.train.Checkpoint(model=model, embedding=embedding)
checkpoint.restore(...)
tables = embedding.embedding_tables
```
You can now use table in functions like `tf.nn.embedding_lookup` to perform
your embedding lookup and pass to your model.
"""
def __init__(
    self,
    feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
    optimizer: Optional[tpu_embedding_v2_utils._Optimizer],  # pylint:disable=protected-access
    pipeline_execution_with_tensor_core: bool = False):
  """Creates the TPUEmbedding mid level API object.

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(
        feature_config=tf.tpu.experimental.embedding.FeatureConfig(
            table=tf.tpu.experimental.embedding.TableConfig(
                dim=...,
                vocabulary_size=...)))
  ```

  Args:
    feature_config: A nested structure of
      `tf.tpu.experimental.embedding.FeatureConfig` configs.
    optimizer: An instance of one of the optimizer classes under
      `tf.tpu.experimental.embedding` (e.g. `SGD`, `Adagrad` or `Adam`). It is
      assigned to every table whose `TableConfig` has no optimizer of its
      own. When not created under a TPUStrategy it may be `None`, which skips
      creating optimizer slot variables -- useful to save memory when
      exporting a model for serving, where slot variables are not needed.
    pipeline_execution_with_tensor_core: If True, the TPU embedding
      computations overlap with the TensorCore computations (and hence are
      one step old). Set to True for improved performance.

  Raises:
    ValueError: If a table's optimizer is not an instance of one of the
      `tf.tpu.experimental.embedding` optimizer classes (`None` is accepted
      only when not under a TPUStrategy), or if two tables share a name.
  """
  self._strategy = distribution_strategy_context.get_strategy()
  self._using_tpu = isinstance(
      self._strategy,
      (tpu_strategy.TPUStrategy, tpu_strategy.TPUStrategyV2))
  self._pipeline_execution_with_tensor_core = (
      pipeline_execution_with_tensor_core)
  self._feature_config = feature_config

  flat_features = nest.flatten(feature_config)
  self._output_shapes = [feature.output_shape for feature in flat_features]

  # The TPU embedding ops do not agree on how to identify a table:
  #   * the enqueue op receives, per feature, a table id equal to the table's
  #     integer index in the proto built by `_create_config_proto()`;
  #   * recv_tpu_embedding_activations emits one lookup per table, in proto
  #     order;
  #   * send_tpu_embedding_gradients expects one input per table, in proto
  #     order;
  #   * the per-optimizer load/retrieve ops address tables by name, not id.
  # So a single table order is fixed here -- the order in which each table
  # first occurs among the user's features, which is deterministic across
  # instantiations -- and every table is given a unique name.
  self._table_config = []
  for feature in flat_features:
    if feature.table not in self._table_config:
      self._table_config.append(feature.table)

  # Optimizers are validated here rather than in TableConfig so that higher
  # level APIs may describe optimizers with strings or other classes before
  # handing the configuration to this class.
  seen_names = []
  for position, table in enumerate(self._table_config):
    if table.optimizer is None:
      # TODO(bfontain) Should we allow some sort of optimizer merging here?
      table.optimizer = optimizer
    optimizer_expected = table.optimizer is not None or self._using_tpu
    if optimizer_expected and not isinstance(
        table.optimizer, tpu_embedding_v2_utils._Optimizer):  # pylint: disable=protected-access
      raise ValueError("{} is an unsupported optimizer class. Please pass an "
                       "instance of one of the optimizer classes under "
                       "tf.tpu.experimental.embedding.".format(
                           type(table.optimizer)))
    if table.name is None:
      table.name = "table_{}".format(position)
    if table.name in seen_names:
      raise ValueError("Tables must have a unique name. "
                       f"Multiple tables with name {table.name} found.")
    seen_names.append(table.name)

  if self._using_tpu:
    # Callable learning rates, de-duplicated, in table order. Each table in
    # the config proto refers to its rate by index into this list, and
    # apply_gradients passes the evaluated rates to
    # send_tpu_embedding_gradients in this same order.
    self._dynamic_learning_rates = []
    for table in self._table_config:
      rate = table.optimizer.learning_rate
      if callable(rate) and rate not in self._dynamic_learning_rates:
        self._dynamic_learning_rates.append(rate)
    # Host devices used by the load/retrieve operations.
    self._hosts = tpu_embedding_v2_utils.get_list_of_hosts(self._strategy)

  self._built = False
  self._verify_output_shapes_on_enqueue = True
def build(self, per_replica_input_shapes=None, per_replica_batch_size=None):  # pylint:disable=g-bare-generic
  """Creates the underlying variables and initializes the TPU for embeddings.

  Creates the embedding tables together with their optimizer slot variables.
  Under a TPUStrategy it first initializes the TPU embedding engine from the
  config proto. Calling this again after a successful build does nothing.

  `enqueue` calls this method automatically and tries to determine the output
  shapes itself; when it cannot, call `build` manually before `enqueue`.

  Args:
    per_replica_input_shapes: A nested structure of per replica input shapes
      matching the structure of the feature config. Each shape should equal
      the input shape of the corresponding feature (ragged tensors excepted).
      The shapes are fixed: the same per replica input shapes must be used
      for both training and evaluation. To derive them from global input
      shapes, use the `num_replicas_in_sync` property of your strategy. May be
      None if not created under a TPUStrategy.
    per_replica_batch_size: (Deprecated) The per replica batch size you intend
      to use. It is fixed: the same batch size must be used for both training
      and evaluation. To derive it from the global batch size, use the
      `num_replicas_in_sync` property of your strategy. May be None if not
      created under a TPUStrategy.

  Raises:
    ValueError: If per_replica_input_shapes is inconsistent with the output
      shapes stored in the feature config, or if the output shapes derived
      from the input shapes are not fully defined.
    RuntimeError: If TPU embedding is already initialized on the TPU.
  """
  if self._built:
    return

  if self._using_tpu:
    # Checked here in Python rather than inside
    # `initialize_system_for_tpu_embedding`, because exception control flow in
    # graph mode is difficult.
    if tpu_ops.is_tpu_embedding_initialized():
      raise RuntimeError(
          "TPU is already initialized for embeddings. This may be caused by "
          "using multiple TPUEmbedding instances in a TPU scope which is "
          "unsupported")
    self._get_and_update_output_shapes_from_input(
        per_replica_input_shapes, per_replica_batch_size)
    self._config_proto = self._create_config_proto()

    logging.info("Initializing TPU Embedding engine.")
    tpu_embedding_v2_utils.log_tpu_embedding_configuration(self._config_proto)

    @def_function.function
    def load_config():
      tpu.initialize_system_for_tpu_embedding(self._config_proto)

    load_config()
    logging.info("Done initializing TPU Embedding engine.")

  # Variables (and slot variables) are kept as a dict of dicts keyed by table
  # name first. TableConfigs would be nicer keys, but then the tracking API
  # would not track these variables properly.
  self._variables = self._create_variables_and_slots()
  self._built = True
  # `_load_variables` is itself conditioned on `_built` and `_using_tpu`.
  self._load_variables()
def _maybe_build(self,
                 output_shapes: Optional[Union[List[int], Iterable]] = None):  # pylint:disable=g-bare-generic
  """Calls `build` once, lifted out of any function currently being traced.

  Args:
    output_shapes: Forwarded to `build` as its first positional argument
      (`per_replica_input_shapes`).
  """
  if self._built:
    return
  # This may run while a function is being traced. `init_scope` makes the
  # initialization execute eagerly instead of being recorded in the traced
  # graph, so the TPU is initialized for embeddings exactly once.
  with ops.init_scope():
    self.build(output_shapes)
def _get_and_update_output_shapes_from_input(
    self,
    per_replica_input_shapes: Optional[List[TensorShape]] = None,
    per_replica_batch_size: Optional[int] = None):
  """Derives per replica output shapes from the inputs and records them.

  `per_replica_input_shapes` takes precedence over `per_replica_batch_size`:
    * an int `per_replica_input_shapes` is treated as a (deprecated) batch
      size;
    * any other non-None `per_replica_input_shapes` is flattened, checked
      against the flattened feature config with `nest.assert_same_structure`
      and converted by `_get_output_shapes_from_input_shapes`;
    * only when `per_replica_input_shapes` is None is a truthy (deprecated)
      `per_replica_batch_size` used. If both are given, the batch size is
      ignored and a warning is logged.

  Any derived shapes are checked against, then merged into, the existing
  output shape settings. In every case the resulting output shapes must be
  fully defined.

  Args:
    per_replica_input_shapes: Nested structure of per replica input shapes
      matching the feature config, or (deprecated) an int batch size.
    per_replica_batch_size: (Deprecated) Per replica batch size; only used
      when `per_replica_input_shapes` is None.
  """
  derived_output_shapes = None
  if per_replica_batch_size and per_replica_input_shapes is None:
    logging.warning(
        "per_replica_batch_size argument will be deprecated, please specify "
        "all the input shapes using per_replica_input_shapes argument.")
    derived_output_shapes = self._get_output_shapes_from_batch_size(
        per_replica_batch_size)
  elif per_replica_input_shapes is not None:
    if per_replica_batch_size:
      # Previously dropped silently; make the precedence visible.
      logging.warning(
          "Both per_replica_input_shapes and per_replica_batch_size were "
          "given; ignoring per_replica_batch_size.")
    if isinstance(per_replica_input_shapes, int):
      logging.warning(
          "Passing batch size to per_replica_input_shapes argument will be"
          " deprecated, please specify all the input shapes using"
          " per_replica_input_shapes argument.")
      derived_output_shapes = self._get_output_shapes_from_batch_size(
          per_replica_input_shapes)
    else:
      flat_input_shapes = nest.flatten(per_replica_input_shapes)
      nest.assert_same_structure(
          flat_input_shapes, nest.flatten(self._feature_config))
      derived_output_shapes = self._get_output_shapes_from_input_shapes(
          flat_input_shapes)

  if derived_output_shapes is not None:
    # Check the derived shapes against the existing output shape settings.
    self._check_output_shapes(derived_output_shapes)
    # Merge them in. The feature config may lack output shapes; they are then
    # supplied either by an explicit `build` call or by the shapes `enqueue`
    # auto-detects on its first call (enqueue is what calls `build`;
    # dequeue/apply_gradients require an already built object).
    self._update_output_shapes(derived_output_shapes)

  # Output shapes must be fully defined so they can be written into the
  # feature descriptors of the TPU embedding config proto.
  self._check_output_shapes_fully_defined()
def _get_output_shapes_from_input_shapes(
    self, input_shapes: List[TensorShape]) -> List[TensorShape]:
  """Maps each flattened input shape to its per feature output shape.

  Args:
    input_shapes: Flat list of input shapes, parallel to
      `nest.flatten(self._feature_config)`.

  Returns:
    A list with one output shape per feature. A rank 1 input shape is kept as
    is; otherwise the last dimension is dropped. Before that, a rank 2 input
    whose last dimension is not 1 gets the feature's `max_sequence_length`
    inserted before its last dimension, provided the feature has no output
    shape configured and `max_sequence_length > 0`.

  Raises:
    ValueError: If an input shape has unknown rank or rank below 1.
  """
  flat_features = nest.flatten(self._feature_config)
  output_shapes = []
  for shape, feature in zip(input_shapes, flat_features):
    if shape.rank is None or shape.rank < 1:
      raise ValueError(
          "Received input tensor of shape {}. Rank must be 1 and above"
          .format(shape))
    # Intended for 2D ragged or sparse features whose output shape is not set
    # in the feature config but which declare a max sequence length.
    expand_with_sequence_length = (
        len(shape) == 2 and shape[-1] != 1 and
        not feature.output_shape and feature.max_sequence_length > 0)
    if expand_with_sequence_length:
      dims = shape.as_list()
      dims.insert(len(dims) - 1, feature.max_sequence_length)
      shape = TensorShape(dims)
    output_shapes.append(shape if shape.rank == 1 else shape[:-1])
  return output_shapes
@property
def embedding_tables(
    self
) -> Dict[tpu_embedding_v2_utils.TableConfig, tf_variables.Variable]:
  """Returns a dict of embedding tables, keyed by `TableConfig`.

  Intended for objects created under a non-TPU strategy, e.g. for CPU based
  lookup when creating a serving checkpoint; accessing it builds the object
  if needed. Under a `TPUStrategy` it is only available inside a save
  context, where the first shard of each table's parameters is returned.

  Returns:
    A dict of embedding tables, keyed by `TableConfig`.

  Raises:
    RuntimeError: If the object was created under a `TPUStrategy` and this is
      accessed outside a save context.
  """
  if not self._using_tpu:
    self._maybe_build(None)
    # Only the tables are returned, never the slot variables. On CPU these
    # are plain tf.Variables.
    return {table: self._variables[table.name]["parameters"]
            for table in self._table_config}

  # Tables are not exposed on TPU because they are sharded and, under a
  # TPUStrategy:
  #   1. the variables are stale and only updated when a checkpoint is made;
  #   2. updating them would not affect the actual tables on the TPU.
  # Saving is the one exception.
  if save_context.in_save_context():
    return {table: self._variables[table.name]["parameters"].variables[0]
            for table in self._table_config}
  raise RuntimeError("Unable to retrieve embedding tables when using a TPU "
                     "strategy. If you need access, save your model, "
                     "create this object under a CPU strategy and restore.")
def _create_config_proto(
    self
) -> tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration:
  """Builds the TPUEmbeddingConfiguration proto that initializes the engine.

  Returns:
    A TPUEmbeddingConfiguration proto with one table descriptor per table (in
    `self._table_config` order), one feature descriptor per flattened
    feature, and the host count, tensor core count, sharding strategy and
    pipelining flag filled in.
  """
  config_cls = tpu_embedding_configuration_pb2.TPUEmbeddingConfiguration
  config_proto = config_cls()

  # A table's callable learning rate (if any) is referenced by its position
  # in the de-duplicated list built in __init__. One learning rate index per
  # table is deliberately avoided: every extra optimization configuration
  # hurts performance considerably.
  learning_rate_index = {
      rate: position
      for position, rate in enumerate(self._dynamic_learning_rates)}
  num_hosts = self._strategy.extended.num_hosts
  for table in self._table_config:
    table._set_table_descriptor(  # pylint: disable=protected-access
        config_proto.table_descriptor.add(), num_hosts, learning_rate_index)

  table_ids = {table: position
               for position, table in enumerate(self._table_config)}
  for feature, output_shape in zip(
      nest.flatten(self._feature_config), self._output_shapes):
    descriptor = config_proto.feature_descriptor.add()
    if feature.name:
      descriptor.name = feature.name
    descriptor.table_id = table_ids[feature.table]
    # The descriptor's "input shape" is the input tensor's shape minus its
    # last dimension, which is always reduced away -- i.e. the output shape.
    descriptor.input_shape.extend(output_shape.as_list())

  # Always TRAINING here; the mode is overridden during enqueue.
  config_proto.mode = config_cls.TRAINING
  config_proto.num_hosts = num_hosts
  config_proto.num_tensor_cores = self._strategy.num_replicas_in_sync
  # TODO(bfontain): Allow users to pick MOD for the host sharding.
  config_proto.sharding_strategy = config_cls.DIV_DEFAULT
  config_proto.pipeline_execution_with_tensor_core = (
      self._pipeline_execution_with_tensor_core)
  return config_proto
def apply_gradients(self, gradients, name: Optional[Text] = None):
  """Applies the gradient update to the embedding tables.

  Wherever the nested structure holds `None`, a zero gradient is sent for
  that feature. For optimizers like SGD or Adagrad this equals applying no
  update at all; for lazy Adam and other sparsely applied optimizers with
  decay, make sure you understand the effect of a zero gradient.

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False))
  dataset_iterator = iter(distributed_dataset)

  @tf.function
  def training_step():
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)

        loss = ... # some computation involving activations

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)

    embedding_features, tpu_features = next(dataset_iterator)
    embedding.enqueue(embedding_features, training=True)
    strategy.run(tpu_step, args=(tpu_features, ))

  training_step()
  ```

  Args:
    gradients: A nested structure of gradients, with structure matching the
      `feature_config` passed to this object.
    name: A name for the underlying op.

  Raises:
    RuntimeError: If called when the object wasn't created under a
      `TPUStrategy`, or before it is built (either by calling build manually
      or by calling enqueue).
    ValueError: If a gradient that is neither a `tf.Tensor` nor `None` is
      passed in, or a `tf.Tensor` of the wrong shape. Also if the size of any
      sequence in `gradients` does not match the corresponding sequence in
      `feature_config`.
    TypeError: If the type of any sequence in `gradients` does not match the
      corresponding sequence in `feature_config`.
  """
  if not self._using_tpu:
    raise RuntimeError("apply_gradients is not valid when TPUEmbedding "
                       "object is not created under a TPUStrategy.")
  if not self._built:
    raise RuntimeError("apply_gradients called on unbuilt TPUEmbedding "
                       "object. Please either call enqueue first or manually "
                       "call the build method.")

  nest.assert_same_structure(self._feature_config, gradients)

  flat_gradients = nest.flatten_with_joined_string_paths(gradients)
  flat_features = nest.flatten(self._feature_config)
  checked_gradients = []
  for (path, gradient), feature, output_shape in zip(
      flat_gradients, flat_features, self._output_shapes):
    expected_shape = list(output_shape) + [feature.table.dim]
    if gradient is None:
      # Every feature must receive a gradient, so substitute zeros. This is
      # not correct for all optimizers.
      logging.warning(
          "No gradient passed for feature %s, sending zero "
          "gradient. This may not be correct behavior for certain "
          "optimizers like Adam.", path)
      gradient = array_ops.zeros(expected_shape, dtype=dtypes.float32)
    elif not isinstance(gradient, ops.Tensor):
      raise ValueError(
          f"found non-tensor type: {type(gradient)} at path {path}.")
    elif gradient.shape != expected_shape:
      raise ValueError("Found gradient of shape {} at path {}. Expected "
                       "shape {}.".format(gradient.shape, path,
                                          expected_shape))
    # Some producing ops leave the static shape unset; reshaping to the
    # gradient's own shape pins it explicitly.
    checked_gradients.append(array_ops.reshape(gradient, shape=gradient.shape))

  learning_rates = [math_ops.cast(fn(), dtype=dtypes.float32)
                    for fn in self._dynamic_learning_rates]
  op = tpu_ops.send_tpu_embedding_gradients(
      inputs=checked_gradients,
      learning_rates=learning_rates,
      config=self._config_proto.SerializeToString())

  if name is not None:
    _add_key_attr(op, name)
def dequeue(self, name: Optional[Text] = None):
  """Returns the embedding activations for the most recent enqueue.

  The result is a nested structure of `tf.Tensor` objects that mirrors the
  `feature_config` argument given to the `TPUEmbedding` constructor. Every
  tensor has shape `(*output_shape, dim)`, where `dim` is the dimension of the
  feature's `TableConfig`. The `output_shape` part comes from the first of the
  following sources that provides it, in this priority order:

  1. The `FeatureConfig` passed to `__init__`.
  2. `per_replica_output_shapes`, when the `build` method is called directly
     after constructing the object.
  3. Automatic detection from the shapes of the enqueued input features.

  ```python
  strategy = tf.distribute.TPUStrategy(...)
  with strategy.scope():
    embedding = tf.tpu.experimental.embedding.TPUEmbedding(...)

  distributed_dataset = (
      strategy.distribute_datasets_from_function(
          dataset_fn=...,
          options=tf.distribute.InputOptions(
              experimental_fetch_to_device=False))
  dataset_iterator = iter(distributed_dataset)

  @tf.function
  def training_step():
    def tpu_step(tpu_features):
      with tf.GradientTape() as tape:
        activations = embedding.dequeue()
        tape.watch(activations)

        loss = ... # some computation involving activations

      embedding_gradients = tape.gradient(loss, activations)
      embedding.apply_gradients(embedding_gradients)

    embedding_features, tpu_features = next(dataset_iterator)
    embedding.enqueue(embedding_features, training=True)
    strategy.run(tpu_step, args=(tpu_features, ))

  training_step()
  ```

  Args:
    name: Optional name for the underlying receive op. When given, it is
      attached to that op via `_add_key_attr`.

  Returns:
    A nested structure of tensors with the same structure as the
    `feature_config` passed to this `TPUEmbedding` object.

  Raises:
    RuntimeError: If the object was not created under a `TPUStrategy`, or if
      it has not been built yet (either by calling `build` manually or by
      calling `enqueue`).
  """
  # Both preconditions are checked in this order so the first failure wins.
  if not self._using_tpu:
    raise RuntimeError("dequeue is not valid when TPUEmbedding object is not "
                       "created under a TPUStrategy.")
  if not self._built:
    raise RuntimeError("dequeue called on unbuilt TPUEmbedding object. "
                       "Please either call enqueue first or manually call "
                       "the build method.")

  serialized_config = self._config_proto.SerializeToString()
  feature_count = len(self._config_proto.feature_descriptor)
  # The receive op yields one activation tensor per feature descriptor.
  per_feature_activations = tpu_ops.recv_tpu_embedding_activations(
      num_outputs=feature_count,
      config=serialized_config)

  if name is not None:
    # Tag the op that produced the activations with the requested name.
    _add_key_attr(per_feature_activations[0].op, name)

  # Restore the flat per-feature list to the nesting of `feature_config`.
  return nest.pack_sequence_as(self._feature_config, per_feature_activations)
def _create_variables_and_slots(
    self
) -> Dict[Text, Dict[Text, tf_variables.Variable]]:
  """Creates the embedding table variables and their optimizer slots.

  When this object uses a TPU, each table's variables are created inside
  `variable_scope.variable_creator_scope(make_sharded_variable_creator(
  self._hosts))`.

  Returns:
    A dict of dicts. The outer dict is keyed by table name. Each inner dict
    maps 'parameters' to the table variable, plus one entry per slot variable
    returned by the table optimizer's `_create_slots` (no slots when the
    table has no optimizer).
  """

  def build_table_vars(table):
    """Creates the parameter variable and any slot variables for `table`."""
    table_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      # _add_variable_with_custom_getter sometimes clears `shape`, so it is
      # ignored here and the table's full shape from the closure is used.
      del shape
      return tf_variables.Variable(
          name=name,
          initial_value=functools.partial(initializer, table_shape,
                                          dtype=dtype),
          shape=table_shape,
          dtype=dtype,
          trainable=trainable)

    def make_var(var_name, var_initializer, trainable=True):
      # Going through _add_variable_with_custom_getter hooks into checkpoint
      # loading, so a restore issued before creation supplies the value
      # directly and avoids initializing the variable twice.
      return self._add_variable_with_custom_getter(
          name=var_name,
          initializer=var_initializer,
          shape=table_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    # The table variable is only trainable when not running on TPU.
    table_var = make_var(table.name, table.initializer,
                         trainable=not self._using_tpu)

    def make_slot(slot_name, slot_initializer):
      # Slot variables live under "<table name>/<slot name>" and are never
      # trainable.
      return make_var(table.name + "/" + slot_name, slot_initializer,
                      trainable=False)

    if table.optimizer is None:
      table_vars = {}
    else:
      table_vars = table.optimizer._create_slots(  # pylint: disable=protected-access
          table_var, make_slot)
    table_vars["parameters"] = table_var
    return table_vars

  # Keyed by table name rather than by TableConfig: tracking cannot follow
  # dicts with non-string keys, so such a dict could not be saved.
  variables = {}
  for table in self._table_config:
    if self._using_tpu:
      with variable_scope.variable_creator_scope(
          make_sharded_variable_creator(self._hosts)):
        variables[table.name] = build_table_vars(table)
    else:
      variables[table.name] = build_table_vars(table)
  return variables
def _load_variables(self):
  """Calls `_load_variables_impl` when it is safe to do so.

  The call is skipped unless this object uses a TPU and has been built. It is
  also skipped while tracing a graph inside a save context; eager execution
  is always allowed. Judging by the name, the impl presumably pushes the
  values in `self._variables` onto the TPU embedding tables; confirm against
  `_load_variables_impl`.
  """
  if not (self._using_tpu and self._built):
    return
  # Equivalent to "not (graph mode and inside a save context)".
  if context.executing_eagerly() or not save_context.in_save_context():
    _load_variables_impl(self._config_proto.SerializeToString(),
                         self._hosts,
                         self._variables,
                         self._table_config)
def _retrieve_variables(self):
  """Calls `_retrieve_variables_impl` when it is safe to do so.

  The call is skipped unless this object uses a TPU and has been built. It is
  also skipped while tracing a graph inside a save context; eager execution
  is always allowed. Judging by the name, the impl presumably copies the TPU
  embedding table values back into `self._variables`; confirm against
  `_retrieve_variables_impl`.
  """
  if not (self._using_tpu and self._built):
    return
  # Equivalent to "not (graph mode and inside a save context)".
  if context.executing_eagerly() or not save_context.in_save_context():
    _retrieve_variables_impl(self._config_proto.SerializeToString(),
                             self._hosts,
                             self._variables,
                             self._table_config)
# Helper functions used by _generate_enqueue_op (below) to append each
# feature's indices/row_splits, values and weights to the enqueue op inputs.
def _add_data_for_tensor(self, tensor, weight, indices, values, weights,
                         int_zeros, float_zeros, path):
  """Appends the enqueue inputs for a dense `Tensor` feature.

  Dense inputs carry no per-id weights, so passing `weight` is an error.
  The caller-supplied `int_zeros` / `float_zeros` placeholders fill the
  indices and weights slots. The tensor is flattened to 1-D and cast to
  int64 for the values slot.

  Args:
    tensor: The dense input feature.
    weight: Must be None for dense inputs.
    indices: List to which the indices placeholder is appended.
    values: List to which the flattened int64 ids are appended.
    weights: List to which the weights placeholder is appended.
    int_zeros: Placeholder appended to `indices`.
    float_zeros: Placeholder appended to `weights`.
    path: Feature path, used only in the error message.

  Raises:
    ValueError: If `weight` is not None. Nothing is appended in that case.
  """
  if weight is None:
    indices.append(int_zeros)
    flattened = array_ops.reshape(tensor, [-1])
    values.append(math_ops.cast(flattened, dtypes.int64))
    weights.append(float_zeros)
    return
  raise ValueError(
      "Weight specified for dense input {}, which is not allowed. "
      "Weight will always be 1 in this case.".format(path))
def _add_data_for_sparse_tensor(self, tensor, weight, indices, values,
                                weights, int_zeros, float_zeros, path,
                                feature):
  """Appends the enqueue inputs for a `SparseTensor` feature.

  The sparse indices are cast to int32 and appended to `indices`.

  - Rank-2 input: when the feature has no `output_shape` and a positive
    `max_sequence_length`, one extra column is padded onto the last axis of
    the indices.
  - Any other rank: a warning is logged if `max_sequence_length` is
    positive, since that setting is ignored.

  The values are cast to int64. Weights, if given, must also be a
  `SparseTensor`; their values are cast to float32. Otherwise `float_zeros`
  is used. `int_zeros` is unused here; the signature mirrors the other
  `_add_data_for_*` helpers.

  Raises:
    ValueError: If `weight` is given and is not a `SparseTensor`.
  """
  sample_indices = math_ops.cast(tensor.indices, dtypes.int32)
  input_rank = tensor.shape.rank
  if input_rank != 2:
    if feature.max_sequence_length > 0:
      logging.warning(
          (
              "Input tensor is rank %d which is above 2, the"
              " max_sequence_length setting will be ignored."
          ),
          input_rank,
      )
  elif not feature.output_shape and feature.max_sequence_length > 0:
    # Add one dimension to the last axis of the indices.
    sample_indices = array_ops.pad(
        sample_indices, paddings=[[0, 0], [0, 1]])

  indices.append(sample_indices)
  values.append(math_ops.cast(tensor.values, dtypes.int64))

  if weight is None:
    weights.append(float_zeros)
    return
  # Weights for a sparse input must themselves be sparse.
  if not isinstance(weight, sparse_tensor.SparseTensor):
    raise ValueError("Weight for {} is type {} which does not match "
                     "type input which is SparseTensor.".format(
                         path, type(weight)))
  weights.append(math_ops.cast(weight.values, dtypes.float32))
def _add_data_for_ragged_tensor(self, tensor, weight, row_splits, values,
                                weights, int_zeros, float_zeros, path,
                                feature):
  """Appends the enqueue inputs for a `RaggedTensor` feature.

  The row splits are cast to int32 and the values to int64. Weights, if
  given, must also be a `RaggedTensor`; their values are cast to float32.
  Otherwise `float_zeros` is used. `int_zeros` and `feature` are unused here;
  the signature mirrors `_add_data_for_sparse_tensor`.

  Raises:
    ValueError: If `weight` is given and is not a `RaggedTensor`.
  """
  row_splits.append(math_ops.cast(tensor.row_splits, dtypes.int32))
  values.append(math_ops.cast(tensor.values, dtypes.int64))

  if weight is None:
    weights.append(float_zeros)
    return
  # Weights for a ragged input must themselves be ragged.
  if not isinstance(weight, ragged_tensor.RaggedTensor):
    raise ValueError("Weight for {} is type {} which does not match "
                     "type input which is RaggedTensor.".format(
                         path, type(weight)))
  weights.append(math_ops.cast(weight.values, dtypes.float32))
def _generate_enqueue_op(
    self,
    flat_inputs: List[internal_types.NativeObject],
    flat_weights: List[Optional[internal_types.NativeObject]],
    flat_features: List[tpu_embedding_v2_utils.FeatureConfig],
    device_ordinal: int,
    mode_override: Text
) -> ops.Operation:
  """Builds the enqueue op for the given inputs and weights.

  Args:
    flat_inputs: A list of input features. Each must be a `Tensor`,
      `SparseTensor` or `RaggedTensor`.
    flat_weights: A list of per-input weights (or None) of the same length
      as `flat_inputs`.
    flat_features: A list of the same length as `flat_inputs` whose elements
      are `(path, FeatureConfig)` pairs. `path` is used in error messages and
      the `FeatureConfig` is consulted for sparse inputs.
      NOTE(review): the type annotation says `FeatureConfig`, but the loop
      below unpacks every element as a `(path, feature)` pair; the
      annotation should be updated to a list of tuples.
    device_ordinal: The device to create the enqueue op for.
    mode_override: Passed through unchanged as the op's `mode_override`,
      presumably holding "train" or "inference". The annotation says `Text`;
      earlier docs described it as a tensor. TODO confirm.

  Returns:
    The enqueue op.

  Raises:
    ValueError: If an input is not a `Tensor`, `SparseTensor` or
      `RaggedTensor`, if a weight is given for a dense `Tensor`, or if a
      weight's type does not match its sparse/ragged input.
  """
  # Combiners are per table, listed in the same order as the tables.
  combiners = [table.combiner for table in self._table_config]

  # These parallel lists become the inputs to the enqueue op:
  # sample_indices for sparse inputs, row_splits for ragged inputs.
  indices_or_row_splits = []
  values = []
  weights = []

  # An empty placeholder tensor must fill every list position that has no
  # data (e.g. indices for a dense Tensor input, or weights when none are
  # given). Creating one of each per call keeps the graph small.
  int_zeros = array_ops.zeros((0,), dtype=dtypes.int32)
  float_zeros = array_ops.zeros((0,), dtype=dtypes.float32)

  # The helpers below cast everything to int32/int64/float32. Op inputs that
  # are lists of tensors must share one dtype within each list, and the CPU
  # implementations of these ops cast to these types anyway, so casting
  # early loses no data.
  for inp, weight, (path, feature) in zip(
      flat_inputs, flat_weights, flat_features):
    if isinstance(inp, ops.Tensor):
      self._add_data_for_tensor(inp, weight, indices_or_row_splits, values,
                                weights, int_zeros, float_zeros, path)
    elif isinstance(inp, sparse_tensor.SparseTensor):
      self._add_data_for_sparse_tensor(inp, weight, indices_or_row_splits,
                                       values, weights, int_zeros,
                                       float_zeros, path, feature)
    elif isinstance(inp, ragged_tensor.RaggedTensor):
      self._add_data_for_ragged_tensor(inp, weight, indices_or_row_splits,
                                       values, weights, int_zeros,
                                       float_zeros, path, feature)
    else:
      raise ValueError("Input {} is of unknown type {}. Please only pass "
                       "Tensor, SparseTensor or RaggedTensor as input to "
                       "enqueue.".format(path, type(inp)))

  return tpu_ops.enqueue_tpu_embedding_arbitrary_tensor_batch(
      sample_indices_or_row_splits=indices_or_row_splits,
      embedding_indices=values,
      aggregation_weights=weights,
      mode_override=mode_override,
      device_ordinal=device_ordinal,
      combiners=combiners)
def _raise_error_for_incorrect_control_flow_context(self):
"""Raises an error if we are not in the TPUReplicateContext."""
# Do not allow any XLA control flow (i.e. control flow in between a