# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Lookup operations."""
from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


def _check_table_initializer_element_spec(element_spec):
  """Raises an error if the given table initializer element spec is invalid."""
  base_error = ("Datasets used to initialize lookup tables must "
                "produce elements in the form (key, value), where "
                "the keys and values are scalar tensors. ")
  if len(element_spec) != 2:
    raise ValueError(base_error + "However, the given dataset produces "
                     f"{len(element_spec)} components instead of two "
                     "(key, value) components. Full dataset element spec: "
                     f"{element_spec}.")
  if not isinstance(element_spec[0], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor keys of type {type(element_spec[0])}.")
  if not isinstance(element_spec[1], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor values of type {type(element_spec[1])}.")
  if element_spec[0].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar key Tensors of rank {element_spec[0].shape.rank}.")
  if element_spec[1].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar value Tensors of rank {element_spec[1].shape.rank}.")
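

# For reference, a dataset accepted by the check above has an element spec of
# the form (an illustrative sketch, not part of the public API):
#
#   (tf.TensorSpec(shape=(), dtype=tf.int64),
#    tf.TensorSpec(shape=(), dtype=tf.string))
#
# as produced by, e.g.:
#
#   tf.data.Dataset.zip((tf.data.Dataset.range(2),
#                        tf.data.Dataset.from_tensor_slices(["a", "b"])))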


@tf_export("data.experimental.DatasetInitializer")
class DatasetInitializer(lookup_ops.TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.data.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)
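
  Keys absent from the initializing dataset fall back to the table's default
  value (the `""` passed to `StaticHashTable` above):

  >>> table.lookup(tf.constant([100], dtype=tf.int64)).numpy()
  array([b''], dtype=object)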

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as the key and the second as the value.

  Raises:
    ValueError: If `dataset` doesn't conform to the specifications above.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as the key and the second as the value.

    Raises:
      ValueError: If `dataset` doesn't conform to the specifications above.

    Returns:
      A `DatasetInitializer` object.
    """
    # Check that the dataset element spec is a two-tuple of scalar TensorSpecs,
    # then derive the table's key and value dtypes from it.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    _check_table_initializer_element_spec(elem_spec)
    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    # Builds (and returns) an op that fills `table` from `self.dataset`; in
    # graph mode the op is also added to the TABLE_INITIALIZERS collection.
    lookup_ops.check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = ged_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
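

# A minimal sketch of how the registered init op is typically run in TF1-style
# graph mode (assumes a `table` built from a `DatasetInitializer` as above; in
# eager mode none of this is needed, since the table initializes on
# construction):
#
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.tables_initializer())  # runs TABLE_INITIALIZERS
#     print(sess.run(table.lookup(tf.constant([0], dtype=tf.int64))))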


@tf_export("data.experimental.table_from_dataset")
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of pairs
  of (key, value).

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)
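
  With `num_oov_buckets` left at zero, keys missing from the dataset fall back
  to `default_value`:

  >>> table.lookup(tf.constant([999], dtype=tf.int64)).numpy()
  array([b'n/a'], dtype=object)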

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs
      * The 2nd item in the `dataset` pairs has a dtype which is incompatible
        with `default_value`
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  elem_spec = dataset.element_spec
  _check_table_initializer_element_spec(elem_spec)
  if default_value is None:
    default_value = -1
    # The implicit -1 default only makes sense for numeric value types; any
    # other value dtype requires an explicit `default_value`.
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("`default_value` must be specified when creating a "
                       "table from a dataset that produces values of type "
                       f"{elem_spec[1].dtype}.")
  if num_oov_buckets < 0:
    raise ValueError("`num_oov_buckets` must be greater than or equal to 0, "
                     f"got {num_oov_buckets}.")
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("`key_dtype` must be either an integer or string type, "
                    f"but got {key_dtype}")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    # Truncate the dataset to `vocab_size` elements and assert the cardinality
    # so downstream transformations can rely on a known size.
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        # Wrap the static table so that lookup misses hash into one of
        # `num_oov_buckets` buckets past the end of the vocabulary.
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
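

# A quick sketch of the OOV-bucket behaviour above, assuming a hypothetical
# three-key vocabulary: with `num_oov_buckets=3`, every out-of-vocabulary key
# hashes deterministically into an id in {3, 4, 5}, i.e. the documented range
# [vocab_size, vocab_size + num_oov_buckets - 1]:
#
#   ds = tf.data.Dataset.zip(
#       (tf.data.Dataset.from_tensor_slices(["a", "b", "c"]),
#        tf.data.Dataset.range(3)))
#   table = tf.data.experimental.table_from_dataset(
#       ds, num_oov_buckets=3, vocab_size=3)
#   table.lookup(tf.constant(["nope"]))  # some id in {3, 4, 5}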


@tf_export("data.experimental.index_table_from_dataset")
def index_table_from_dataset(dataset=None,
                             num_oov_buckets=0,
                             vocab_size=None,
                             default_value=-1,
                             hasher_spec=lookup_ops.FastHashSpec,
                             key_dtype=dtypes.string,
                             name=None):
  """Returns an index lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of keys.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> ds = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> table = tf.data.experimental.index_table_from_dataset(
  ...     ds, key_dtype=tf.string)
  >>> table.lookup(tf.constant(['0', '2', '4'], dtype=tf.string)).numpy()
  array([0, 1, 2])
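
  With OOV buckets enabled, unknown keys hash to ids past the vocabulary. A
  sketch, assuming the 100-element `ds` above; the result is deterministic
  only because there is a single bucket, so every miss maps to id 100:

  >>> oov_table = tf.data.experimental.index_table_from_dataset(
  ...     ds, num_oov_buckets=1)
  >>> oov_table.lookup(tf.constant(['oov'])).numpy()
  array([100])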

  Args:
    dataset: A dataset of keys.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  # `enumerate()` yields (index, element) pairs; the map swaps them so the
  # dataset elements become the keys and their positions become the values.
  return table_from_dataset(dataset.enumerate().map(lambda v, k: (k, v)),
                            num_oov_buckets, vocab_size, default_value,
                            hasher_spec, key_dtype, name)