-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
base.py
196 lines (151 loc) · 5.49 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Base decoders.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import functools
import six
import tensorflow.compat.v2 as tf
from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core.utils import py_utils
@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
  """Abstract base class for customizing how a feature is decoded.

  A `tfds.decode.Decoder` replaces the default decoding of a feature. Either
  subclass it, skip decoding altogether with `tfds.decode.SkipDecoding`, or
  build one from a plain function with the `tfds.decode.make_decoder`
  decorator.

  Every decoder must derive from this class. Once `setup` has been called,
  the `self.feature` attribute holds the `FeatureConnector` the decoder is
  attached to.

  Subclasses must implement `decode_example`, which receives the serialized
  feature and returns the decoded one. When `decode_example` produces a dtype
  different from the feature's own, also override the `dtype` property so the
  decoder stays usable inside `tfds.features.Sequence`.
  """

  def __init__(self):
    # No feature is known at construction time; `setup` fills this in later.
    self.feature = None

  @api_utils.disallow_positional_args
  def setup(self, feature):
    """Binds the decoder to the feature it will decode.

    Initialization is deferred to this method because a decoder only learns
    which builder/feature it is used with after it has been constructed.

    Args:
      feature: `tfds.features.FeatureConnector`, the feature this decoder is
        applied to. The method is wrapped with `disallow_positional_args`, so
        pass it by keyword.
    """
    self.feature = feature

  @property
  def dtype(self):
    """Returns the `dtype` (possibly nested) of the decoded feature."""
    # Keep only the dtype of every entry of the feature's tensor info.
    return py_utils.map_nested(
        lambda info: info.dtype,
        self.feature.get_tensor_info(),
    )

  @abc.abstractmethod
  def decode_example(self, serialized_example):
    """Decodes a single feature field (e.g. an image).

    Args:
      serialized_example: `tf.Tensor` to decode; its dtype/shape should match
        `feature.get_serialized_info()`.

    Returns:
      example: The decoded example.
    """
    raise NotImplementedError('Abstract class')

  def decode_batch_example(self, serialized_example):
    """See `FeatureConnector.decode_batch_example` for details."""
    # `tf.map_fn` needs the output dtype up front, hence the `dtype` property.
    output_dtype = self.dtype
    return tf.map_fn(
        self.decode_example,
        serialized_example,
        dtype=output_dtype,
        parallel_iterations=10,
        back_prop=False,
        name='sequence_decode',
    )
class SkipDecoding(Decoder):
  """Decoder which skips the decoding entirely.

  The serialized feature is returned unchanged.

  Example of usage:

  ```python
  ds = tfds.load(
      'imagenet2012',
      split='train',
      decoders={
          'image': tfds.decode.SkipDecoding(),
      }
  )

  for ex in ds.take(1):
    assert ex['image'].dtype == tf.string
  ```
  """

  @property
  def dtype(self):
    """Returns the `dtype` (possibly nested) of the serialized feature."""
    # Nothing is decoded, so report the dtypes of the serialized info rather
    # than those of the decoded tensor info.
    return py_utils.map_nested(
        lambda info: info.dtype,
        self.feature.get_serialized_info(),
    )

  def decode_example(self, serialized_example):
    """Returns the serialized feature field as-is."""
    return serialized_example
class DecoderFn(Decoder):
  """Decoder wrapping a function, built by `tfds.decode.make_decoder`."""

  def __init__(self, fn, output_dtype, *args, **kwargs):
    """Stores the decoding function and the arguments to forward to it.

    Args:
      fn: Callable invoked as `fn(example, feature, *args, **kwargs)`.
      output_dtype: dtype reported by the `dtype` property, or `None` to fall
        back to the base class `dtype`.
      *args: Extra positional arguments forwarded to `fn`.
      **kwargs: Extra keyword arguments forwarded to `fn`.
    """
    super(DecoderFn, self).__init__()
    self._fn, self._output_dtype = fn, output_dtype
    self._args, self._kwargs = args, kwargs

  @property
  def dtype(self):
    """Returns the explicit output dtype if given, else the base `dtype`."""
    if self._output_dtype is not None:
      return self._output_dtype
    return super(DecoderFn, self).dtype

  def decode_example(self, serialized_example):
    """Decodes the example by calling the wrapped function."""
    extra_args, extra_kwargs = self._args, self._kwargs
    return self._fn(
        serialized_example, self.feature, *extra_args, **extra_kwargs)
def make_decoder(output_dtype=None):
  """Decorator turning a function into a decoder factory.

  The decorated function should have the signature `(example, feature, *args,
  **kwargs) -> decoded_example`.

   * `example`: Serialized example before decoding
   * `feature`: `FeatureConnector` associated with the example
   * `*args, **kwargs`: Optional additional arguments forwarded to the
     function

  Calling the decorated function does not run it: it returns a `DecoderFn`
  holding the function together with the call arguments.

  Example:

  ```
  @tfds.decode.make_decoder(output_dtype=tf.string)
  def no_op_decoder(example, feature):
    # Decoder simply decoding the feature normally.
    return feature.decode_example(example)

  tfds.load('mnist', split='train', decoders={
      'image': no_op_decoder(),
  })
  ```

  Args:
    output_dtype: The output dtype after decoding. Required only if the decoded
      example has a different type than the `FeatureConnector.dtype` and is
      used to decode features inside sequences (ex: videos)

  Returns:
    A decorator which wraps the function into a `DecoderFn` factory.
  """

  def decorator(fn):

    def decorated(*args, **kwargs):
      # Defer the actual decoding: only capture `fn` and its extra arguments.
      return DecoderFn(fn, output_dtype, *args, **kwargs)

    # Copy `fn`'s name, docstring, etc. onto the factory; returns `decorated`.
    return functools.update_wrapper(decorated, fn)

  return decorator