tpu_embedding_base.py
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base Class for TPU Embeddings Mid level APIs."""
import functools
from typing import Any, Dict, Iterable, Optional, Union, Text
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.tpu import tpu_embedding_v2_utils
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import nest
class TPUEmbeddingBase(autotrackable.AutoTrackable):
  """The TPUEmbedding Base class.

  This class only contains the basic logic to check the feature config and
  table config for the tpu embedding mid level APIs.
  """
  def __init__(
      self,
      feature_config: Union[tpu_embedding_v2_utils.FeatureConfig, Iterable],  # pylint:disable=g-bare-generic
      optimizer: Optional[tpu_embedding_v2_utils._Optimizer] = None):  # pylint:disable=protected-access
    """Creates the TPUEmbeddingBase object."""
    self._feature_config = feature_config
    self._output_shapes = []
    for feature in nest.flatten(feature_config):
      self._output_shapes.append(feature.output_shape)

    # Set table order here to the order of the first occurrence of the table
    # in a feature provided by the user. The order of this struct must be
    # fixed to provide the user with deterministic behavior over multiple
    # instantiations.
    self._table_config = []
    for feature in nest.flatten(feature_config):
      if feature.table not in self._table_config:
        self._table_config.append(feature.table)

    # Ensure tables have unique names. Also error check the optimizer as we
    # specifically don't do that in the TableConfig class to allow high level
    # APIs that are built on this to use strings/other classes to represent
    # optimizers (before they are passed to this class).
    table_names = []
    for i, table in enumerate(self._table_config):
      if table.optimizer is None:
        # TODO(bfontain): Should we allow some sort of optimizer merging here?
        table.optimizer = optimizer
      if (table.optimizer is not None and
          not isinstance(table.optimizer, tpu_embedding_v2_utils._Optimizer)):  # pylint: disable=protected-access
        raise ValueError(
            f"{type(table.optimizer)} is an unsupported optimizer class. "
            "Please pass an instance of one of the optimizer classes under "
            "tf.tpu.experimental.embedding.")
      if table.name is None:
        table.name = f"table_{i}"
      if table.name in table_names:
        raise ValueError("Tables must have a unique name. "
                         f"Multiple tables with name {table.name} found.")
      table_names.append(table.name)

    self._built = False
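
  # Illustrative note (not in the original source): because tables are
  # deduplicated above, two features that share one `TableConfig` produce a
  # single table. For example, if hypothetical `watched` and `liked` features
  # both point at the same `video_table`, `self._table_config` ends up as
  # `[video_table]`, and the table is auto-named "table_0" when no name is
  # given.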

  @property
  def embedding_tables(self):
    """Returns a dict of embedding tables, keyed by `TableConfig`."""
    raise NotImplementedError

  def _create_variables(self, table: tpu_embedding_v2_utils.TableConfig,
                        trainable: bool) -> Dict[Text, tf_variables.Variable]:
    """Create all variables including table variables and slot variables."""
    variable_shape = (table.vocabulary_size, table.dim)

    def getter(name, shape, dtype, initializer, trainable):
      del shape
      # _add_variable_with_custom_getter clears the shape sometimes, so we
      # take the global shape from outside the getter.
      initial_value = functools.partial(initializer, variable_shape,
                                        dtype=dtype)
      return tf_variables.Variable(
          name=name,
          initial_value=initial_value,
          shape=variable_shape,
          dtype=dtype,
          trainable=trainable)

    def variable_creator(name, initializer, trainable=True):
      # Use add_variable_with_custom_getter here so that we take advantage of
      # the checkpoint loading to allow restore before the variables get
      # created which avoids double initialization.
      return self._add_variable_with_custom_getter(
          name=name,
          initializer=initializer,
          shape=variable_shape,
          dtype=dtypes.float32,
          getter=getter,
          trainable=trainable)

    parameters = variable_creator(
        table.name, table.initializer, trainable=trainable)

    def slot_creator(name, initializer):
      return variable_creator(table.name + "/" + name, initializer, False)

    if table.optimizer is not None:
      slot_vars = table.optimizer._create_slots(parameters, slot_creator)  # pylint: disable=protected-access
    else:
      slot_vars = {}
    slot_vars["parameters"] = parameters
    return slot_vars
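
  # Illustrative note (not in the original source): the slot keys in the
  # returned dict come from the optimizer's `_create_slots`. For a table
  # using Adagrad the result is roughly
  # `{"accumulators": <slot Variable>, "parameters": <table Variable>}`,
  # with every Variable shaped `(vocabulary_size, dim)`.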

  def _create_variables_and_slots(self):
    """Creates the table variables and slot variables for TPU embeddings."""
    raise NotImplementedError

  def build(self):
    """Creates the variables and slot variables, if not already built."""
    if self._built:
      return
    self._variables = self._create_variables_and_slots()
    self._built = True

  def __call__(self, features: Any, weights: Optional[Any] = None) -> Any:
    """Calls the mid level API to do embedding lookup."""
    if not self._built:
      self.build()
    return self.embedding_lookup(features, weights)

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Looks up the embedding tables using the input features."""
    raise NotImplementedError
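
# -----------------------------------------------------------------------------
# Minimal subclass sketch (not part of the original file): it shows how the
# three `NotImplementedError` hooks above fit together. `_DenseEmbedding` and
# its gather-based lookup are assumptions made for illustration; the real mid
# level APIs shard tables across TPU cores instead of using one dense variable.
from tensorflow.python.ops import array_ops  # Only needed by the sketch below.


class _DenseEmbedding(TPUEmbeddingBase):
  """Toy implementation keeping each table as a single dense variable."""

  def _create_variables_and_slots(self):
    # One dict (parameters plus any optimizer slots) per configured table.
    return [
        self._create_variables(table, trainable=True)
        for table in self._table_config
    ]

  @property
  def embedding_tables(self):
    self.build()
    return {
        table: variables["parameters"]
        for table, variables in zip(self._table_config, self._variables)
    }

  def embedding_lookup(self,
                       features: Any,
                       weights: Optional[Any] = None) -> Any:
    """Gathers one row per id from each feature's table."""
    del weights  # A real implementation would apply per-id weights here.
    tables = self.embedding_tables
    flat_configs = nest.flatten(self._feature_config)
    flat_features = nest.flatten(features)
    flat_outputs = [
        array_ops.gather(tables[config.table], ids)
        for config, ids in zip(flat_configs, flat_features)
    ]
    return nest.pack_sequence_as(features, flat_outputs)


# Hypothetical usage of the sketch above:
#   table = tpu_embedding_v2_utils.TableConfig(vocabulary_size=8, dim=4)
#   feature = tpu_embedding_v2_utils.FeatureConfig(table=table)
#   mid_level = _DenseEmbedding(
#       feature, optimizer=tpu_embedding_v2_utils.SGD(learning_rate=0.1))
#   activations = mid_level(tf.constant([0, 3, 5]))  # Builds, then looks up.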