-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathpreprocess_spec.py
383 lines (314 loc) · 13.9 KB
/
preprocess_spec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# Copyright 2024 The CLU Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Library for parsing a preprocessing spec.
A preprocessing spec is a list of preprocessing ops separated by '|' that can be
applied sequentially as a preprocessing function. The preprocessing ops are
provided as input and must implement the PreprocessOp protocol. While not
strictly required we also recommend annotating preprocess ops as dataclasses.
By convention the preprocessing function operates on dictionaries of features.
Each op can change the dictionary by modifying, adding or removing dictionary
entries. Dictionary entries should be tensors, keys should be strings.
(For common data types we recommend using the feature keys used in TFDS.)
Example spec: 'fn1|fn2(3)|fn3(keyword=5)'
This will construct the following preprocessing function:
def preprocess_fn(features: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
  features = fn1(features)
  features = fn2(features, 3)
  features = fn3(features, keyword=5)
  return features
See preprocess_spec_test.py for some simple examples.
"""
import abc
import ast
import dataclasses
import inspect
import re
import sys
from typing import Dict, List, Sequence, Tuple, Type, TypeVar, Union
from absl import logging
from flax import traverse_util
import jax.numpy as jnp
import tensorflow as tf
import typing_extensions
from typing_extensions import Protocol
# Tensor types allowed as feature values: dense, ragged or sparse.
Tensor = Union[tf.Tensor, tf.RaggedTensor, tf.SparseTensor]
# Feature dictionary. TFDS allows arbitrarily nested `Features` with string
# keys and tensors as leaves ...
Features = Dict[str, Union[Tensor, "Features"]]
# ... but it's usually a better idea NOT to nest them. Also better for PyType.
FlatFeatures = Dict[str, Tensor]
# Either a single flat feature dictionary or a `tf.data.Dataset`; used by ops
# whose `__call__` accepts both.
D = TypeVar("D", FlatFeatures, tf.data.Dataset)

# Feature name for the random seed for tf.random.stateless_* ops. By
# convention ops should split off their random seed and keep the SEED_KEY
# feature:
# ```
# features[SEED_KEY], seed = tf.unstack(
#     tf.random.experimental.stateless_split(features[SEED_KEY]))
# ```
SEED_KEY = "_seed"

# Zero-width pattern matching the position in front of every upper case
# character except at the very start of the string. Substituting "_" at these
# positions (and lower-casing) turns CamelCase class names into snake_case.
_CAMEL_CASE_RGX = re.compile(r"(?<!^)(?=[A-Z])")
@typing_extensions.runtime_checkable
class PreprocessOp(Protocol):
  """Structural interface for preprocess ops that transform `Features`.

  Implementations do not need to inherit from this protocol: any class whose
  `__call__` accepts and returns features satisfies it.

  Declaring preprocess ops with `@dataclasses.dataclass(frozen=True)` is not
  required but strongly recommended; it keeps the code short and gives the op
  a readable __str__().

  get_all_ops() only returns dataclasses; the other helpers in this module are
  meant to work with any class implementing this protocol.
  """

  def __call__(self, features: Features) -> Features:
    """Applies the preprocessing op to the features."""
    ...
# Deprecated. Please use `grain.tensorflow.MapTransform`.
class MapTransform(abc.ABC):
  """Base class for transformations of single elements.

  Besides implementing the PreprocessOp interface, this class:
  - Restricts the features to a flat dictionary (instead of an arbitrarily
    nested one).
  - Provides a `__call__` that applies the transformation either to a single
    example (`FlatFeatures`) or, via `tf.data.Dataset.map`, to every element
    of a `tf.data.Dataset`. The latter is convenient for SeqIO users migrating
    to preprocess_spec.py. For chaining multiple transformations we still
    recommend the `PreprocessFn` class.
  - Requires subclasses to be dataclasses.
  """

  def __new__(cls, *args, **kwargs):
    # Constructor arguments are handled by `__init__`, not here.
    del args, kwargs
    # This check cannot live in `__init_subclass__`: that hook runs while the
    # class statement executes, i.e. before `@dataclasses.dataclass` has
    # processed the new class, so it is not a dataclass yet at that point.
    if dataclasses.is_dataclass(cls):
      return super().__new__(cls)
    raise ValueError(
        f"Class {cls} is not a dataclass. We strongly recommend annotating "
        "transformations with `@dataclasses.dataclass(frozen=True)`.")

  def __call__(self, features: D) -> D:
    """Transforms a single example, or every element of a dataset."""
    logging.warning("clu.preprocess_spec.MapTransform is deprecated. Please "
                    "switch to grain.tensorflow.MapTransform.")
    if not isinstance(features, tf.data.Dataset):
      return self._transform(features)
    return features.map(self._transform, num_parallel_calls=tf.data.AUTOTUNE)

  @abc.abstractmethod
  def _transform(self, features: FlatFeatures) -> FlatFeatures:
    """Transforms a single flat feature dictionary and returns the result."""
# Deprecated. Please use `grain.tensorflow.RandomMapTransform`.
class RandomMapTransform(MapTransform, abc.ABC):
  """Base class for random transformations of single elements.

  We require all random transformations to use stateless random operations
  (e.g. `tf.random.stateless_uniform()`) and respect the provided random seed.
  The user can expect the random seed to be unique for the element.

  The seed is taken from the `SEED_KEY` feature, which must be present in the
  input. It is split into a seed handed to `_transform` and a new seed that is
  stored back under `SEED_KEY` in the returned features.

  If multiple random seeds are required the user can split the seed into N
  new seeds:
  ```
  seeds = tf.unstack(tf.random.experimental.stateless_split(seed, N))
  ```
  """

  def __call__(self, features: D) -> D:
    """Applies the transformation to a single example or to a dataset.

    Args:
      features: A flat feature dictionary containing `SEED_KEY`, or a
        `tf.data.Dataset` of such dictionaries.

    Returns:
      The transformed features (or dataset). The caller's dictionary is not
      modified.

    Raises:
      KeyError: If a feature dictionary does not contain `SEED_KEY`.
    """
    logging.warning("clu.preprocess_spec.RandomMapTransform is deprecated. "
                    "Please switch to grain.tensorflow.RandomMapTransform.")
    if isinstance(features, tf.data.Dataset):
      # Each dataset element is routed back through this method's dict path.
      return features.map(self, num_parallel_calls=tf.data.AUTOTUNE)
    if SEED_KEY not in features:
      raise KeyError(
          f"{type(self).__name__} requires the random seed feature "
          f"{SEED_KEY!r}, but the features only contain {list(features)}.")
    # Work on a shallow copy: popping SEED_KEY from the caller's dictionary
    # would otherwise leak out of this call (e.g. when `_transform` returns a
    # new dictionary, the caller's input would silently lose its seed).
    features = features.copy()
    next_seed, seed = tf.unstack(
        tf.random.experimental.stateless_split(features.pop(SEED_KEY)))
    features = self._transform(features, seed)
    features[SEED_KEY] = next_seed
    return features

  @abc.abstractmethod
  def _transform(self, features: FlatFeatures, seed: tf.Tensor) -> FlatFeatures:  # pytype: disable=signature-mismatch  # overriding-parameter-count-checks
    """Transforms the features only using stateless random ops."""
# Deprecated. Please use `grain.tensorflow.FilterTransform`.
class FilterTransform(abc.ABC):
  """Base class for ops that keep only dataset elements passing a predicate."""

  def __call__(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Returns the dataset filtered by `self._predicate`."""
    logging.warning("clu.preprocess_spec.FilterTransform is deprecated. Please "
                    "switch to grain.tensorflow.FilterTransform.")
    keep_element = self._predicate
    return dataset.filter(keep_element)

  @abc.abstractmethod
  def _predicate(self, features: FlatFeatures) -> tf.Tensor:
    """Returns True if the element should be kept."""
def get_all_ops(module_name: str) -> List[Tuple[str, Type[PreprocessOp]]]:
  """Collects all preprocess ops found in a module.

  Modules that define preprocessing ops can simply define:

    all_ops = lambda: preprocess_spec.get_all_ops(__name__)

  all_ops() will then return every dataclass among the module's attributes
  that implements the PreprocessOp protocol.

  Args:
    module_name: Name of the module. The module must already be imported.

  Returns:
    List of `(name, op_class)` tuples. `name` is the class name converted to
    snake case (MyAwesomeTransform => my_awesome_transform).
  """
  module = sys.modules[module_name]
  found = []
  for class_name, candidate in inspect.getmembers(module):
    if not inspect.isclass(candidate):
      continue
    if not dataclasses.is_dataclass(candidate):
      continue
    if not issubclass(candidate, PreprocessOp):
      continue
    snake_name = _CAMEL_CASE_RGX.sub("_", class_name).lower()
    found.append((snake_name, candidate))
  return found
def _jax_supported_tf_types():
  """Returns the TF dtypes whose name also exists as an attribute of jnp."""
  supported = []
  for _, member in inspect.getmembers(tf.dtypes):
    if isinstance(member, tf.dtypes.DType) and hasattr(jnp, member.name):
      supported.append(member)
  # TF's boolean dtype is named "bool" while jax.numpy calls it `bool_`, so the
  # name lookup above may miss it; always include it explicitly.
  supported.append(tf.bool)
  return supported
@dataclasses.dataclass
class OnlyJaxTypes:
  """Removes all features whose types are not supported by JAX.

  Features are filtered by dtype; in addition, sparse and ragged tensors are
  removed because they have no equivalent in JAX.

  Attributes:
    types: List of allowed dtypes. Defaults to all TF dtypes that have an
      equivalent in jax.numpy.
  """
  types: List[tf.dtypes.DType] = dataclasses.field(
      default_factory=_jax_supported_tf_types)

  def __call__(self, features: Features) -> Features:
    # Flatten so that nested feature dictionaries are filtered leaf by leaf.
    flat = traverse_util.flatten_dict(features)
    # Iterate over a snapshot of the keys because entries are deleted below.
    for key in list(flat):
      value = flat[key]
      if value.dtype not in self.types:
        logging.warning(
            "Removing feature %r because dtype %s is not supported in JAX.",
            key, value.dtype)
      elif isinstance(value, tf.SparseTensor):
        logging.warning(
            "Removing feature %r because sparse tensors are not "
            "supported in JAX.", key)
      elif isinstance(value, tf.RaggedTensor):
        logging.warning(
            "Removing feature %r because ragged tensors are not support in "
            "JAX.", key)
      else:
        continue  # Supported feature: keep it.
      del flat[key]
    return traverse_util.unflatten_dict(flat)  # pytype: disable=bad-return-type
@dataclasses.dataclass
class PreprocessFn:
  """Chain of preprocessing ops combined into a single preprocessing function.

  Attributes:
    ops: List of feature transformations, applied in the given order.
    only_jax_types: If True, the `OnlyJaxTypes` transformation is applied after
      all `ops`.
  """
  ops: Sequence[PreprocessOp]
  only_jax_types: bool

  def __call__(self, features: Features) -> Features:
    """Applies every op in `self.ops` in order and returns the result."""
    logging.info("Features before preprocessing: %s",
                 _describe_features(features))
    # Shallow copy: ops that modify their input in place then cannot change the
    # caller's top-level dictionary.
    result = features.copy()
    for op in self.ops:
      result = op(result)
      logging.info("Features after op %s:\n%s", op, _describe_features(result))
    logging.info("Features after preprocessing: %s",
                 _describe_features(result))
    return OnlyJaxTypes()(result) if self.only_jax_types else result

  def __add__(self, other: "PreprocessFn") -> "PreprocessFn":
    """Concatenates two `PreprocessFn`s.

    The ops of `self` run first, then those of `other`. `OnlyJaxTypes` is
    applied if either operand requests it.

    Raises:
      ValueError: If `other` is not a `PreprocessFn`.
    """
    if isinstance(other, PreprocessFn):
      combined_ops = (*self.ops, *other.ops)
      return PreprocessFn(
          ops=combined_ops,
          only_jax_types=self.only_jax_types or other.only_jax_types,
      )
    raise ValueError("Can only add other instances of `PreprocessFn`.")

  def __getitem__(self, op_index: Union[int, slice]) -> "PreprocessFn":
    """Returns a `PreprocessFn` holding only the selected op(s)."""
    if isinstance(op_index, slice):
      selected = self.ops[op_index]
    else:
      selected = [self.ops[op_index]]
    return PreprocessFn(ops=selected, only_jax_types=self.only_jax_types)
def _get_op_class(
    expr: ast.expr,
    available_ops: Dict[str, Type[PreprocessOp]]) -> Type[PreprocessOp]:
  """Looks up the preprocess op class named by a parsed spec expression.

  Args:
    expr: Body of a single op spec parsed in "eval" mode. Must be a bare name
      (`fn`) or a call of a bare name (`fn(...)`).
    available_ops: Available preprocessing ops, keyed by op name.

  Returns:
    The op class registered under the expression's name.

  Raises:
    ValueError: If `expr` does not name an op directly, or the named op is not
      in `available_ops`.
  """
  # `expr.func` must be checked too: for attribute calls (`mod.fn(1)`) or
  # chained calls (`fn()(1)`) it is not an `ast.Name` and has no `.id`.
  if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name):
    fn_name = expr.func.id
  elif isinstance(expr, ast.Name):
    fn_name = expr.id
  else:
    raise ValueError(
        f"Could not parse function name from expression: {expr!r}.")
  if fn_name in available_ops:
    return available_ops[fn_name]
  raise ValueError(
      f"'{fn_name}' is not available (available ops: {list(available_ops)}).")
def _parse_single_preprocess_op(
    spec: str, available_ops: Dict[str, Type[PreprocessOp]]) -> PreprocessOp:
  """Parses the spec for a single preprocess op.

  The op can just be the op name or the op name followed by any arguments
  (both positional and keyword). Argument values must be Python literals since
  they are evaluated with `ast.literal_eval`. See the test cases for some valid
  examples.

  Args:
    spec: String specifying a single preprocessing operation, e.g. 'fn2(3)'.
      Surrounding whitespace is ignored.
    available_ops: Available preprocessing ops, keyed by op name.

  Returns:
    The PreprocessOp corresponding to the spec.

  Raises:
    ValueError: If the spec cannot be parsed, names an unknown op, uses a
      non-literal argument or `**` unpacking, passes more positional arguments
      than a dataclass op accepts, or gives an argument both positionally and
      by keyword.
  """
  # In "eval" mode ast.parse() rejects leading whitespace ("unexpected
  # indent"), which is common in specs such as 'fn1 | fn2'.
  try:
    expr = ast.parse(spec.strip(), mode="eval").body  # pytype: disable=attribute-error
  except SyntaxError as e:
    raise ValueError(f"{spec!r} is not a valid preprocess op spec.") from e
  op_class = _get_op_class(expr, available_ops)  # pytype: disable=wrong-arg-types

  # Simple case without arguments.
  if isinstance(expr, ast.Name):
    return op_class()
  assert isinstance(expr, ast.Call)

  # ast.literal_eval raises ValueError for non-literal values (names, calls,
  # `*args` unpacking, ...).
  args = [ast.literal_eval(arg) for arg in expr.args]
  kwargs = {}
  for kv in expr.keywords:
    # `kv.arg` is None for `**mapping` unpacking, which has no single name.
    if kv.arg is None:
      raise ValueError(
          f"{spec!r}: `**` argument unpacking is not supported in op specs.")
    kwargs[kv.arg] = ast.literal_eval(kv.value)
  if not args:
    return op_class(**kwargs)

  if not dataclasses.is_dataclass(op_class):
    # Without dataclass fields positional arguments cannot be mapped to names;
    # rely on the op's regular call signature instead.
    return op_class(*args, **kwargs)

  # Translate positional arguments into keyword arguments following the order
  # of the dataclass constructor, which excludes fields declared init=False.
  available_arg_names = [
      f.name for f in dataclasses.fields(op_class) if f.init]
  if len(args) > len(available_arg_names):
    raise ValueError(
        f"{op_class} accepts at most {len(available_arg_names)} positional "
        f"arguments, but {len(args)} were given in {spec!r}.")
  for name, arg in zip(available_arg_names, args):
    if name in kwargs:
      raise ValueError(
          f"Argument {name} to {op_class} given both as positional argument "
          f"(value: {arg}) and keyword argument (value: {kwargs[name]}).")
    kwargs[name] = arg
  return op_class(**kwargs)
def parse(spec: str,
          available_ops: List[Tuple[str, Type[PreprocessOp]]],
          *,
          only_jax_types: bool = True) -> PreprocessFn:
  """Parses a preprocess spec; a '|' separated list of preprocess ops.

  Args:
    spec: The spec, e.g. 'fn1|fn2(3)|fn3(keyword=5)'. A blank spec yields a
      `PreprocessFn` without ops.
    available_ops: `(name, op_class)` pairs, e.g. as returned by
      `get_all_ops()`, that the ops in `spec` may refer to.
    only_jax_types: Passed through to the returned `PreprocessFn`.

  Returns:
    A `PreprocessFn` that applies the parsed ops in order.
  """
  ops_by_name = dict(available_ops)
  op_specs = spec.split("|") if spec.strip() else []
  ops = [
      _parse_single_preprocess_op(op_spec, ops_by_name) for op_spec in op_specs
  ]
  return PreprocessFn(ops, only_jax_types=only_jax_types)
def _describe_features(features: Features) -> str:
  """Returns a string summarizing the dtype and shape of every feature.

  Nested dictionaries are described recursively. Values that are neither
  tensors nor dictionaries are reported as unsupported.
  """
  summary = {}
  for key, value in features.items():
    if isinstance(value, dict):
      summary[key] = _describe_features(value)
    elif isinstance(value, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
      summary[key] = f"{value.dtype.name}{value.shape}"
    else:
      summary[key] = f"Unsupported type {type(value)} at feature '{key}'."
  return str(summary)