/
parsing_ops.py
2220 lines (1938 loc) · 93.6 KB
/
parsing_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Parsing Ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import,undefined-variable
from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# Register the op types below as not differentiable through
# `ops.NotDifferentiable`; each call takes the op type name as a string.
ops.NotDifferentiable("DecodeRaw")
ops.NotDifferentiable("DecodePaddedRaw")
ops.NotDifferentiable("ParseTensor")
ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")
@tf_export("io.VarLenFeature", v1=["VarLenFeature", "io.VarLenFeature"])
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
  """Configuration for parsing a variable-length input feature.

  The parse functions in this module map each `VarLenFeature` to a
  `SparseTensor` of the given `dtype`.

  Fields:
    dtype: Data type of input.
  """
@tf_export("io.SparseFeature", v1=["io.SparseFeature", "SparseFeature"])
class SparseFeature(
    collections.namedtuple(
        "SparseFeature",
        ["index_key", "value_key", "dtype", "size", "already_sorted"])):
  """Configuration for parsing a sparse input feature from an `Example`.

  Prefer `VarLenFeature` (possibly combined with a `SequenceExample`) for
  parsing out `SparseTensor`s; it is simpler than `SparseFeature`.

  A `SparseFeature` closely mirrors the `SparseTensor` that is obtained by
  parsing an `Example` with it. It contains a

  * `value_key`: The key of the `Feature` in the `Example` whose parsed
    `Tensor` becomes the resulting `SparseTensor.values`.

  * `index_key`: A list of names, one per dimension of the resulting
    `SparseTensor`. `indices[i][dim]`, the position of the `i`-th value in
    dimension `dim`, equals the `i`-th value of the `Feature` with key
    `index_key[dim]` in the `Example`.

  * `size`: A list of ints for the resulting `SparseTensor.dense_shape`.

  For example, the following 2D `SparseTensor`

  ```python
  SparseTensor(indices=[[3, 1], [20, 0]],
               values=[0.5, -1.0],
               dense_shape=[100, 3])
  ```

  can be represented with an `Example` input proto

  ```python
  features {
    feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
    feature { key: "ix0" value { int64_list { value: [ 3, 20 ] } } }
    feature { key: "ix1" value { int64_list { value: [ 1, 0 ] } } }
  }
  ```

  and a `SparseFeature` config with 2 `index_key`s

  ```python
  SparseFeature(index_key=["ix0", "ix1"],
                value_key="val",
                dtype=tf.float32,
                size=[100, 3])
  ```

  Fields:
    index_key: A single string name or a list of string names of index
      features. The underlying feature of every key must have type `int64`,
      and its length must always match that of the `value_key` feature.
      Use a list of length `rank` to represent `SparseTensor`s whose
      `dense_shape` has a `rank` higher than 1.
    value_key: Name of the value feature. The underlying feature's type must
      be `dtype`, and its length must always match that of all the
      `index_key`s' features.
    dtype: Data type of the `value_key` feature.
    size: A Python int or a list of ints specifying the dense shape. It
      should be a list if and only if `index_key` is a list, in which case
      its length must equal the length of `index_key`. For each entry `i`,
      all values in the `index_key[i]` feature must be in `[0, size[i])`.
    already_sorted: A Python boolean stating whether the values in
      `value_key` are already sorted by their index position. If so, sorting
      is skipped. False by default (optional).
  """

  def __new__(cls, index_key, value_key, dtype, size, already_sorted=False):
    # Overridden only to make `already_sorted` optional; the fields are
    # forwarded to the namedtuple constructor in declaration order.
    field_values = (index_key, value_key, dtype, size, already_sorted)
    return super(SparseFeature, cls).__new__(cls, *field_values)
@tf_export("io.FixedLenFeature", v1=["io.FixedLenFeature", "FixedLenFeature"])
class FixedLenFeature(collections.namedtuple(
    "FixedLenFeature", ["shape", "dtype", "default_value"])):
  """Configuration for parsing a fixed-length input feature.

  Supply a `default_value` to treat sparse input as dense; without one, the
  parse functions fail on any example that is missing this feature.

  Fields:
    shape: Shape of input data.
    dtype: Data type of input.
    default_value: Value used when an example is missing this feature. It
      must be compatible with `dtype` and of the specified `shape`.
  """

  def __new__(cls, shape, dtype, default_value=None):
    # Overridden only to make `default_value` optional.
    field_values = (shape, dtype, default_value)
    return super(FixedLenFeature, cls).__new__(cls, *field_values)
@tf_export("io.FixedLenSequenceFeature",
           v1=["io.FixedLenSequenceFeature", "FixedLenSequenceFeature"])
class FixedLenSequenceFeature(collections.namedtuple(
    "FixedLenSequenceFeature",
    ["shape", "dtype", "allow_missing", "default_value"])):
  """Configuration for parsing a variable-length input feature into a `Tensor`.

  Parsing a single `SequenceExample` or `Example` yields a `Tensor` with a
  static `shape` of `[None] + shape` and the specified `dtype`.

  Parsing `batch_size` many `Example`s yields a `Tensor` with a static `shape`
  of `[batch_size, None] + shape` and the specified `dtype`. Entries of the
  `batch` that come from different `Example`s are padded with `default_value`
  up to the maximum length present in the `batch`.

  Pass `allow_missing=True` to treat a sparse input as dense; otherwise the
  parse functions fail on any example that is missing this feature.

  Fields:
    shape: Shape of input data for dimension 2 and higher. The first dimension
      is of variable length `None`.
    dtype: Data type of input.
    allow_missing: Whether this feature may be missing from a feature list
      item. Is available only for parsing `SequenceExample`, not for parsing
      `Examples`.
    default_value: Scalar value used to pad multiple `Example`s to their
      maximum length. Irrelevant for parsing a single `Example` or
      `SequenceExample`. Defaults to "" for dtype string and 0 otherwise
      (optional).
  """

  def __new__(cls, shape, dtype, allow_missing=False, default_value=None):
    # Overridden only to make `allow_missing` and `default_value` optional.
    field_values = (shape, dtype, allow_missing, default_value)
    return super(FixedLenSequenceFeature, cls).__new__(cls, *field_values)
def _features_to_raw_params(features, types):
  """Splits feature configs into the raw params used by `gen_parsing_ops`.

  Keys are visited in sorted order so the returned lists are deterministic.

  Args:
    features: A `dict` mapping feature keys to objects of a type in `types`.
    types: Type of features to allow, among `FixedLenFeature`, `VarLenFeature`,
      `SparseFeature`, and `FixedLenSequenceFeature`.

  Returns:
    Tuple of `sparse_keys`, `sparse_types`, `dense_keys`, `dense_types`,
    `dense_defaults`, `dense_shapes`. All of them are empty when `features`
    is empty or `None`.

  Raises:
    ValueError: if `features` contains an item not in `types`, or an invalid
      feature.
  """
  sparse_keys, sparse_types = [], []
  dense_keys, dense_types, dense_shapes = [], [], []
  # A plain dict could yield the defaults in a different order when the graph
  # is built twice, which would break the _e2e_test that expects exactly the
  # same graph. An OrderedDict keeps the insertion order stable.
  dense_defaults = collections.OrderedDict()
  raw_params = (sparse_keys, sparse_types, dense_keys, dense_types,
                dense_defaults, dense_shapes)
  if not features:
    return raw_params

  # NOTE: Sorted keys keep the output deterministic.
  for key in sorted(features.keys()):
    spec = features[key]

    if isinstance(spec, VarLenFeature):
      if VarLenFeature not in types:
        raise ValueError("Unsupported VarLenFeature %s." % (spec,))
      if not spec.dtype:
        raise ValueError("Missing type for feature %s." % key)
      sparse_keys.append(key)
      sparse_types.append(spec.dtype)
      continue

    if isinstance(spec, SparseFeature):
      if SparseFeature not in types:
        raise ValueError("Unsupported SparseFeature %s." % (spec,))
      if not spec.index_key:
        raise ValueError(
            "Missing index_key for SparseFeature %s." % (spec,))
      if not spec.value_key:
        raise ValueError(
            "Missing value_key for SparseFeature %s." % (spec,))
      if not spec.dtype:
        raise ValueError("Missing type for feature %s." % key)

      index_keys = spec.index_key
      if isinstance(index_keys, str):
        index_keys = [index_keys]
      elif len(index_keys) > 1:
        tf_logging.warning("SparseFeature is a complicated feature config "
                           "and should only be used after careful "
                           "consideration of VarLenFeature.")

      # Each index feature is parsed as an int64 sparse feature. A key that
      # was registered earlier must already have that type.
      for index_key in sorted(index_keys):
        if index_key not in sparse_keys:
          sparse_keys.append(index_key)
          sparse_types.append(dtypes.int64)
          continue
        known_dtype = sparse_types[sparse_keys.index(index_key)]
        if known_dtype != dtypes.int64:
          raise ValueError("Conflicting type %s vs int64 for feature %s." %
                           (known_dtype, index_key))

      # The value feature is parsed as a sparse feature of `spec.dtype`, with
      # the same conflict check against earlier registrations.
      value_key = spec.value_key
      if value_key not in sparse_keys:
        sparse_keys.append(value_key)
        sparse_types.append(spec.dtype)
      else:
        known_dtype = sparse_types[sparse_keys.index(value_key)]
        if known_dtype != spec.dtype:
          raise ValueError("Conflicting type %s vs %s for feature %s." % (
              known_dtype, spec.dtype, value_key))
      continue

    if isinstance(spec, FixedLenFeature):
      if FixedLenFeature not in types:
        raise ValueError("Unsupported FixedLenFeature %s." % (spec,))
      if not spec.dtype:
        raise ValueError("Missing type for feature %s." % key)
      if spec.shape is None:
        raise ValueError("Missing shape for feature %s." % key)
      static_shape = tensor_shape.as_shape(spec.shape)
      first_dim_unknown = (spec.shape and static_shape.ndims and
                           static_shape.dims[0].value is None)
      if first_dim_unknown:
        raise ValueError("First dimension of shape for feature %s unknown. "
                         "Consider using FixedLenSequenceFeature." % key)
      # `spec.shape` cannot be None here; that case was rejected above.
      if not static_shape.is_fully_defined():
        raise ValueError("All dimensions of shape for feature %s need to be "
                         "known but received %s." % (key, str(spec.shape)))
      dense_keys.append(key)
      dense_shapes.append(spec.shape)
      dense_types.append(spec.dtype)
      if spec.default_value is not None:
        dense_defaults[key] = spec.default_value
      continue

    if isinstance(spec, FixedLenSequenceFeature):
      if FixedLenSequenceFeature not in types:
        raise ValueError("Unsupported FixedLenSequenceFeature %s." % (
            spec,))
      if not spec.dtype:
        raise ValueError("Missing type for feature %s." % key)
      if spec.shape is None:
        raise ValueError("Missing shape for feature %s." % key)
      dense_keys.append(key)
      dense_shapes.append(spec.shape)
      dense_types.append(spec.dtype)
      # `allow_missing` registers a `None` default; an explicit
      # `default_value` then replaces it.
      if spec.allow_missing:
        dense_defaults[key] = None
      if spec.default_value is not None:
        dense_defaults[key] = spec.default_value
      continue

    raise ValueError("Invalid feature %s:%s." % (key, spec))

  return raw_params
def _construct_sparse_tensors_for_sparse_features(features, tensor_dict):
  """Merges the index and value `SparseTensor`s of each `SparseFeature`.

  Builds a new dict from `tensor_dict`. For every `SparseFeature` among the
  values of `features`, the entries of `tensor_dict` under its `index_key`s
  and `value_key` are combined into a single `SparseTensor`, which is stored
  in the result under the feature's own key in `features`. Other entries of
  `tensor_dict` are copied only if their key is present in `features`.

  Args:
    features: A `dict` mapping feature keys to `SparseFeature` values.
      Values of other types will be ignored.
    tensor_dict: A `dict` mapping feature keys to `Tensor` and `SparseTensor`
      values. Expected to contain the keys of the `SparseFeature`s'
      `index_key`s and `value_key`s, mapping them to `SparseTensor`s. It is
      not modified.

  Returns:
    A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.
    Similar to `tensor_dict`, except that each `SparseFeature` in `features`
    results in a single `SparseTensor`.
  """
  merged = dict(tensor_dict)  # Work on a copy; the argument stays untouched.

  for key in sorted(features.keys()):
    spec = features[key]
    if not isinstance(spec, SparseFeature):
      continue
    # A single string `index_key` gives one ids tensor; otherwise a list with
    # one ids tensor per index key is passed on.
    if isinstance(spec.index_key, str):
      sp_ids = merged[spec.index_key]
    else:
      sp_ids = [merged[index_key] for index_key in spec.index_key]
    merged[key] = sparse_ops.sparse_merge(
        sp_ids,
        merged[spec.value_key],
        vocab_size=spec.size,
        already_sorted=spec.already_sorted)

  # Drop the tensors that were only needed as inputs for the merges above,
  # i.e. every key that is not a key of `features`.
  helper_keys = set(merged) - set(features)
  for helper_key in helper_keys:
    del merged[helper_key]
  return merged
def _prepend_none_dimension(features):
  """Prepends a `None` dimension to each `FixedLenSequenceFeature` shape.

  Args:
    features: A `dict` mapping feature keys to feature configs, or an empty /
      `None` value.

  Returns:
    `features` itself when it is empty or `None`. Otherwise a new `dict` in
    which every `FixedLenSequenceFeature` is replaced by a copy whose shape is
    `[None] + list(shape)`; all other entries are carried over unchanged.

  Raises:
    ValueError: if a `FixedLenSequenceFeature` has `allow_missing` unset.
  """
  if not features:
    return features

  updated = dict(features)  # Copy so the caller's dict is left as it was.
  for key, spec in features.items():
    if not isinstance(spec, FixedLenSequenceFeature):
      continue
    if not spec.allow_missing:
      raise ValueError("Unsupported: FixedLenSequenceFeature requires "
                       "allow_missing to be True.")
    updated[key] = FixedLenSequenceFeature(
        shape=[None] + list(spec.shape),
        dtype=spec.dtype,
        allow_missing=spec.allow_missing,
        default_value=spec.default_value)
  return updated
@tf_export(v1=["io.parse_example", "parse_example"])
def parse_example(serialized, features, name=None, example_names=None):
  # pylint: disable=line-too-long
  """Parses serialized `Example` protos into a `dict` of tensors.

  `serialized` holds a number of serialized
  [`Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
  protos. It is treated as a batch with `batch_size` entries, each entry being
  one `Example` proto.

  `example_names` may hold a descriptive name for each serialized proto. The
  names can be useful for debugging, but they have no effect on the output. If
  it is not `None`, `example_names` must have the same length as `serialized`.

  The serialized examples are parsed into a dictionary mapping keys to `Tensor`
  and `SparseTensor` objects. `features` is a dict from keys to
  `VarLenFeature`, `SparseFeature`, `FixedLenFeature` and
  `FixedLenSequenceFeature` objects. Each `VarLenFeature` and `SparseFeature`
  is mapped to a `SparseTensor`; each `FixedLenFeature` and
  `FixedLenSequenceFeature` is mapped to a `Tensor`.

  Each `VarLenFeature` maps to a `SparseTensor` of the specified type that
  represents a ragged matrix. Its indices are `[batch, index]`, where `batch`
  identifies the example in `serialized` and `index` is the value's index in
  the list of values associated with that feature and example.

  Each `SparseFeature` maps to a `SparseTensor` of the specified type that
  represents a Tensor of `dense_shape` `[batch_size] + SparseFeature.size`.
  Its `values` come from the feature with key `value_key` in the examples.
  A `values[i]` comes from a position `k` in the feature of the example at
  batch entry `batch`. That position is recorded in `indices[i]` as
  `[batch, index_0, index_1, ...]`, where `index_j` is the `k`-th value of the
  feature with key `SparseFeature.index_key[j]` in that example. In other
  words, the indices of the `SparseTensor` (except the first one, which
  identifies the batch entry) are split by dimension into different features
  of the `Example`. Because of this complexity, a `VarLenFeature` should be
  preferred over a `SparseFeature` whenever possible.

  Each `FixedLenFeature` `df` maps to a `Tensor` of the specified type (or
  `tf.float32` if not specified) and shape `(serialized.size(),) + df.shape`.

  `FixedLenFeature` entries with a `default_value` are optional. Without a
  default value, parsing fails if that `Feature` is missing from any example
  in `serialized`.

  Each `FixedLenSequenceFeature` `df` maps to a `Tensor` of the specified type
  (or `tf.float32` if not specified) and shape
  `(serialized.size(), None) + df.shape`. All examples in `serialized` are
  padded with `default_value` along the second dimension.

  Examples:

  Suppose a `tf.float32` `VarLenFeature` `ft` is expected and three serialized
  `Example`s are provided:

  ```
  serialized = [
    features
      { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
    features
      { feature []},
    features
      { feature { key: "ft" value { float_list { value: [3.0] } } } }
  ]
  ```

  then the output will look like:

  ```python
  {"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
                      values=[1.0, 2.0, 3.0],
                      dense_shape=(3, 2)) }
  ```

  If instead a `FixedLenSequenceFeature` with `default_value = -1.0` and
  `shape=[]` is used, then the output will look like:

  ```python
  {"ft": [[1.0, 2.0], [3.0, -1.0]]}
  ```

  Given two `Example` input protos in `serialized`:

  ```
  [
    features {
      feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } }
      feature { key: "gps" value { float_list { value: [] } } }
    },
    features {
      feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } }
      feature { key: "dank" value { int64_list { value: [ 42 ] } } }
      feature { key: "gps" value { } }
    }
  ]
  ```

  and the arguments

  ```
  example_names: ["input0", "input1"],
  features: {
      "kw": VarLenFeature(tf.string),
      "dank": VarLenFeature(tf.int64),
      "gps": VarLenFeature(tf.float32),
  }
  ```

  the output is a dictionary:

  ```python
  {
    "kw": SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=["knit", "big", "emmy"],
        dense_shape=[2, 2]),
    "dank": SparseTensor(
        indices=[[1, 0]],
        values=[42],
        dense_shape=[2, 1]),
    "gps": SparseTensor(
        indices=[],
        values=[],
        dense_shape=[2, 0]),
  }
  ```

  For dense results in two serialized `Example`s:

  ```
  [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
     },
     features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
    }
  ]
  ```

  the following arguments can be used:

  ```
  example_names: ["input0", "input1"],
  features: {
      "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
      "gender": FixedLenFeature([], dtype=tf.string),
  }
  ```

  and the expected output is:

  ```python
  {
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
  }
  ```

  `SparseFeature` is an alternative to `VarLenFeature` for obtaining a
  `SparseTensor`. For example, given two `Example` input protos in
  `serialized`:

  ```
  [
    features {
      feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
      feature { key: "ix" value { int64_list { value: [ 3, 20 ] } } }
    },
    features {
      feature { key: "val" value { float_list { value: [ 0.0 ] } } }
      feature { key: "ix" value { int64_list { value: [ 42 ] } } }
    }
  ]
  ```

  and the arguments

  ```
  example_names: ["input0", "input1"],
  features: {
      "sparse": SparseFeature(
          index_key="ix", value_key="val", dtype=tf.float32, size=100),
  }
  ```

  the output is a dictionary:

  ```python
  {
    "sparse": SparseTensor(
        indices=[[0, 3], [0, 20], [1, 42]],
        values=[0.5, -1.0, 0.0],
        dense_shape=[2, 100]),
  }
  ```

  Args:
    serialized: A vector (1-D Tensor) of strings, a batch of binary
      serialized `Example` protos.
    features: A `dict` mapping feature keys to `FixedLenFeature`,
      `VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature` values.
    name: A name for this operation (optional).
    example_names: A vector (1-D Tensor) of strings (optional), the names of
      the serialized protos in the batch.

  Returns:
    A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.

  Raises:
    ValueError: if any feature is invalid.
  """
  # This signature takes `name` before `example_names`, while
  # `parse_example_v2` takes them the other way round; passing by keyword
  # keeps the mapping explicit.
  return parse_example_v2(
      serialized=serialized,
      features=features,
      example_names=example_names,
      name=name)
@tf_export("io.parse_example", v1=[])
def parse_example_v2(serialized, features, example_names=None, name=None):
  # pylint: disable=line-too-long
  """Parses serialized `Example` protos into a `dict` of tensors.

  `serialized` holds a number of serialized
  [`Example`](https://www.tensorflow.org/code/tensorflow/core/example/example.proto)
  protos. It is treated as a batch with `batch_size` entries, each entry being
  one `Example` proto.

  `example_names` may hold a descriptive name for each serialized proto. The
  names can be useful for debugging, but they have no effect on the output. If
  it is not `None`, `example_names` must have the same length as `serialized`.

  The serialized examples are parsed into a dictionary mapping keys to `Tensor`
  and `SparseTensor` objects. `features` is a dict from keys to
  `VarLenFeature`, `SparseFeature`, `FixedLenFeature` and
  `FixedLenSequenceFeature` objects. Each `VarLenFeature` and `SparseFeature`
  is mapped to a `SparseTensor`; each `FixedLenFeature` and
  `FixedLenSequenceFeature` is mapped to a `Tensor`.

  Each `VarLenFeature` maps to a `SparseTensor` of the specified type that
  represents a ragged matrix. Its indices are `[batch, index]`, where `batch`
  identifies the example in `serialized` and `index` is the value's index in
  the list of values associated with that feature and example.

  Each `SparseFeature` maps to a `SparseTensor` of the specified type that
  represents a Tensor of `dense_shape` `[batch_size] + SparseFeature.size`.
  Its `values` come from the feature with key `value_key` in the examples.
  A `values[i]` comes from a position `k` in the feature of the example at
  batch entry `batch`. That position is recorded in `indices[i]` as
  `[batch, index_0, index_1, ...]`, where `index_j` is the `k`-th value of the
  feature with key `SparseFeature.index_key[j]` in that example. In other
  words, the indices of the `SparseTensor` (except the first one, which
  identifies the batch entry) are split by dimension into different features
  of the `Example`. Because of this complexity, a `VarLenFeature` should be
  preferred over a `SparseFeature` whenever possible.

  Each `FixedLenFeature` `df` maps to a `Tensor` of the specified type (or
  `tf.float32` if not specified) and shape `(serialized.size(),) + df.shape`.

  `FixedLenFeature` entries with a `default_value` are optional. Without a
  default value, parsing fails if that `Feature` is missing from any example
  in `serialized`.

  Each `FixedLenSequenceFeature` `df` maps to a `Tensor` of the specified type
  (or `tf.float32` if not specified) and shape
  `(serialized.size(), None) + df.shape`. All examples in `serialized` are
  padded with `default_value` along the second dimension.

  Examples:

  Suppose a `tf.float32` `VarLenFeature` `ft` is expected and three serialized
  `Example`s are provided:

  ```
  serialized = [
    features
      { feature { key: "ft" value { float_list { value: [1.0, 2.0] } } } },
    features
      { feature []},
    features
      { feature { key: "ft" value { float_list { value: [3.0] } } } }
  ]
  ```

  then the output will look like:

  ```python
  {"ft": SparseTensor(indices=[[0, 0], [0, 1], [2, 0]],
                      values=[1.0, 2.0, 3.0],
                      dense_shape=(3, 2)) }
  ```

  If instead a `FixedLenSequenceFeature` with `default_value = -1.0` and
  `shape=[]` is used, then the output will look like:

  ```python
  {"ft": [[1.0, 2.0], [3.0, -1.0]]}
  ```

  Given two `Example` input protos in `serialized`:

  ```
  [
    features {
      feature { key: "kw" value { bytes_list { value: [ "knit", "big" ] } } }
      feature { key: "gps" value { float_list { value: [] } } }
    },
    features {
      feature { key: "kw" value { bytes_list { value: [ "emmy" ] } } }
      feature { key: "dank" value { int64_list { value: [ 42 ] } } }
      feature { key: "gps" value { } }
    }
  ]
  ```

  and the arguments

  ```
  example_names: ["input0", "input1"],
  features: {
      "kw": VarLenFeature(tf.string),
      "dank": VarLenFeature(tf.int64),
      "gps": VarLenFeature(tf.float32),
  }
  ```

  the output is a dictionary:

  ```python
  {
    "kw": SparseTensor(
        indices=[[0, 0], [0, 1], [1, 0]],
        values=["knit", "big", "emmy"],
        dense_shape=[2, 2]),
    "dank": SparseTensor(
        indices=[[1, 0]],
        values=[42],
        dense_shape=[2, 1]),
    "gps": SparseTensor(
        indices=[],
        values=[],
        dense_shape=[2, 0]),
  }
  ```

  For dense results in two serialized `Example`s:

  ```
  [
    features {
      feature { key: "age" value { int64_list { value: [ 0 ] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
     },
     features {
      feature { key: "age" value { int64_list { value: [] } } }
      feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
    }
  ]
  ```

  the following arguments can be used:

  ```
  example_names: ["input0", "input1"],
  features: {
      "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
      "gender": FixedLenFeature([], dtype=tf.string),
  }
  ```

  and the expected output is:

  ```python
  {
    "age": [[0], [-1]],
    "gender": [["f"], ["f"]],
  }
  ```

  `SparseFeature` is an alternative to `VarLenFeature` for obtaining a
  `SparseTensor`. For example, given two `Example` input protos in
  `serialized`:

  ```
  [
    features {
      feature { key: "val" value { float_list { value: [ 0.5, -1.0 ] } } }
      feature { key: "ix" value { int64_list { value: [ 3, 20 ] } } }
    },
    features {
      feature { key: "val" value { float_list { value: [ 0.0 ] } } }
      feature { key: "ix" value { int64_list { value: [ 42 ] } } }
    }
  ]
  ```

  and the arguments

  ```
  example_names: ["input0", "input1"],
  features: {
      "sparse": SparseFeature(
          index_key="ix", value_key="val", dtype=tf.float32, size=100),
  }
  ```

  the output is a dictionary:

  ```python
  {
    "sparse": SparseTensor(
        indices=[[0, 3], [0, 20], [1, 42]],
        values=[0.5, -1.0, 0.0],
        dense_shape=[2, 100]),
  }
  ```

  Args:
    serialized: A vector (1-D Tensor) of strings, a batch of binary
      serialized `Example` protos.
    features: A `dict` mapping feature keys to `FixedLenFeature`,
      `VarLenFeature`, `SparseFeature`, and `FixedLenSequenceFeature` values.
      Must not be empty.
    example_names: A vector (1-D Tensor) of strings (optional), the names of
      the serialized protos in the batch.
    name: A name for this operation (optional).

  Returns:
    A `dict` mapping feature keys to `Tensor` and `SparseTensor` values.

  Raises:
    ValueError: if `features` is empty or `None`, or if any feature is
      invalid.
  """
  if not features:
    # Format through a 1-tuple: a bare empty tuple on the right-hand side of
    # `%` would raise TypeError instead of the documented ValueError.
    raise ValueError("Missing: features was %s." % (features,))
  # Every FixedLenSequenceFeature gets a leading `None` dimension before the
  # configs are split into raw op parameters.
  features = _prepend_none_dimension(features)
  (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
   dense_shapes) = _features_to_raw_params(
       features,
       [VarLenFeature, SparseFeature, FixedLenFeature, FixedLenSequenceFeature])
  outputs = _parse_example_raw(
      serialized, example_names, sparse_keys, sparse_types, dense_keys,
      dense_types, dense_defaults, dense_shapes, name)
  # Merge the per-key index/value tensors of each SparseFeature into a single
  # SparseTensor and drop the helper entries.
  return _construct_sparse_tensors_for_sparse_features(features, outputs)
def _parse_example_raw(serialized,
                       names=None,
                       sparse_keys=None,
                       sparse_types=None,
                       dense_keys=None,
                       dense_types=None,
                       dense_defaults=None,
                       dense_shapes=None,
                       name=None):
  """Parses `Example` protos.

  Args:
    serialized: A vector (1-D Tensor) of strings, a batch of binary
      serialized `Example` protos.
    names: A vector (1-D Tensor) of strings (optional), the names of
      the serialized protos.
    sparse_keys: A list of string keys in the examples' features.
      The results for these keys are returned as `SparseTensor` objects.
    sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
      Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
      and `tf.string` (`BytesList`) are supported.
    dense_keys: A list of string keys in the examples' features.
      The results for these keys are returned as `Tensor`s.
    dense_types: A list of DTypes of the same length as `dense_keys`.
      Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
      and `tf.string` (`BytesList`) are supported.
    dense_defaults: A dict mapping string keys to `Tensor`s.
      The keys of the dict must match the dense_keys of the feature.
    dense_shapes: A list of tuples with the same length as `dense_keys`.
      The shape of the data for each dense feature referenced by `dense_keys`.
      Required for any input tensors identified by `dense_keys`. Must be
      either fully defined, or may contain an unknown first dimension.
      An unknown first dimension means the feature is treated as having
      a variable number of blocks, and the output shape along this dimension
      is considered unknown at graph build time. Padding is applied for
      minibatch elements smaller than the maximum number of blocks for the
      given feature along this dimension.
    name: A name for this operation (optional).

  Returns:
    A `dict` mapping keys to `Tensor`s and `SparseTensor`s.
  """
  with ops.name_scope(name, "ParseExample", [serialized, names]):
    processed = _process_raw_parameters(
        names, dense_defaults, sparse_keys, sparse_types, dense_keys,
        dense_types, dense_shapes)
    # The helper returns seven values; the last one is not used here.
    (names, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
     dense_shapes, _) = processed

    op_outputs = gen_parsing_ops.parse_example(
        serialized=serialized,
        names=names,
        dense_defaults=dense_defaults_vec,
        sparse_keys=sparse_keys,
        sparse_types=sparse_types,
        dense_keys=dense_keys,
        dense_shapes=dense_shapes,
        name=name)
    sparse_indices, sparse_values, sparse_shapes, dense_values = op_outputs

    # Each (indices, values, shape) triple forms one SparseTensor.
    sparse_tensors = [
        sparse_tensor.SparseTensor(*triple)
        for triple in zip(sparse_indices, sparse_values, sparse_shapes)]

    # Sparse results come first, then dense ones, each paired with its key.
    all_keys = sparse_keys + dense_keys
    all_values = sparse_tensors + dense_values
    return dict(zip(all_keys, all_values))
def _process_raw_parameters(names, dense_defaults, sparse_keys, sparse_types,
dense_keys, dense_types, dense_shapes):
"""Process raw parameters to params used by `gen_parsing_ops`.
Args:
names: A vector (1-D Tensor) of strings (optional), the names of
the serialized protos.
dense_defaults: A dict mapping string keys to `Tensor`s.
The keys of the dict must match the dense_keys of the feature.
sparse_keys: A list of string keys in the examples' features.
The results for these keys will be returned as `SparseTensor` objects.
sparse_types: A list of `DTypes` of the same length as `sparse_keys`.
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
and `tf.string` (`BytesList`) are supported.
dense_keys: A list of string keys in the examples' features.
The results for these keys will be returned as `Tensor`s
dense_types: A list of DTypes of the same length as `dense_keys`.
Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
and `tf.string` (`BytesList`) are supported.
dense_shapes: A list of tuples with the same length as `dense_keys`.
The shape of the data for each dense feature referenced by `dense_keys`.
Required for any input tensors identified by `dense_keys`. Must be
either fully defined, or may contain an unknown first dimension.
An unknown first dimension means the feature is treated as having
a variable number of blocks, and the output shape along this dimension
is considered unknown at graph build time. Padding is applied for
minibatch elements smaller than the maximum number of blocks for the
given feature along this dimension.
Returns:
Tuple of `names`, `dense_defaults_vec`, `sparse_keys`, `sparse_types`,
`dense_keys`, `dense_shapes`.
Raises:
ValueError: If sparse and dense key sets intersect, or input lengths do not
match up.
"""
names = [] if names is None else names
dense_defaults = collections.OrderedDict(
) if dense_defaults is None else dense_defaults
sparse_keys = [] if sparse_keys is None else sparse_keys
sparse_types = [] if sparse_types is None else sparse_types
dense_keys = [] if dense_keys is None else dense_keys
dense_types = [] if dense_types is None else dense_types
dense_shapes = ([[]] * len(dense_keys)
if dense_shapes is None else dense_shapes)
num_dense = len(dense_keys)
num_sparse = len(sparse_keys)
if len(dense_shapes) != num_dense:
raise ValueError("len(dense_shapes) != len(dense_keys): %d vs. %d" %
(len(dense_shapes), num_dense))
if len(dense_types) != num_dense:
raise ValueError("len(dense_types) != len(num_dense): %d vs. %d" %
(len(dense_types), num_dense))
if len(sparse_types) != num_sparse:
raise ValueError("len(sparse_types) != len(sparse_keys): %d vs. %d" %
(len(sparse_types), num_sparse))
if num_dense + num_sparse == 0:
raise ValueError("Must provide at least one sparse key or dense key")
if not set(dense_keys).isdisjoint(set(sparse_keys)):
raise ValueError(
"Dense and sparse keys must not intersect; intersection: %s" %
set(dense_keys).intersection(set(sparse_keys)))
# Convert dense_shapes to TensorShape object.
dense_shapes = [tensor_shape.as_shape(shape) for shape in dense_shapes]
dense_defaults_vec = []
for i, key in enumerate(dense_keys):
default_value = dense_defaults.get(key)
dense_shape = dense_shapes[i]
if (dense_shape.ndims is not None and dense_shape.ndims > 0 and
dense_shape.dims[0].value is None):
# Variable stride dense shape, the default value should be a
# scalar padding value
if default_value is None:
default_value = ops.convert_to_tensor(
"" if dense_types[i] == dtypes.string else 0, dtype=dense_types[i])
else:
# Reshape to a scalar to ensure user gets an error if they
# provide a tensor that's not intended to be a padding value
# (0 or 2+ elements).
key_name = "padding_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dense_types[i], name=key_name)
default_value = array_ops.reshape(default_value, [])
else:
if default_value is None:
default_value = constant_op.constant([], dtype=dense_types[i])
elif not isinstance(default_value, ops.Tensor):
key_name = "key_" + re.sub("[^A-Za-z0-9_.\\-/]", "_", key)
default_value = ops.convert_to_tensor(
default_value, dtype=dense_types[i], name=key_name)
default_value = array_ops.reshape(default_value, dense_shape)
dense_defaults_vec.append(default_value)
# Finally, convert dense_shapes to TensorShapeProto
dense_shapes_as_proto = [shape.as_proto() for shape in dense_shapes]
return (names, dense_defaults_vec, sparse_keys, sparse_types, dense_keys,
dense_shapes_as_proto, dense_shapes)
@tf_export(v1=["io.parse_single_example", "parse_single_example"])
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
Similar to `parse_example`, except:
For dense tensors, the returned `Tensor` is identical to the output of
`parse_example`, except there is no batch dimension, the output shape is the
same as the shape given in `dense_shape`.
For `SparseTensor`s, the first (batch) column of the indices matrix is removed
(the indices matrix is a column vector), the values vector is unchanged, and
the first (`batch_size`) entry of the shape vector is removed (it is now a
single element vector).
One might see performance advantages by batching `Example` protos with