/
stats_generator.py
478 lines (369 loc) · 16 KB
/
stats_generator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for statistics generators."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
from typing import Any, Dict, Generic, Hashable, Iterable, List, Optional, Text, TypeVar
import apache_beam as beam
import pyarrow as pa
from tensorflow_data_validation import types
from tensorflow_data_validation.statistics.generators import input_batch
from tensorflow_metadata.proto.v0 import schema_pb2
from tensorflow_metadata.proto.v0 import statistics_pb2
class StatsGenerator(object):
  """Base class for all statistics generators.

  Holds the generator's unique name and an optional dataset schema; the
  actual statistics computation is supplied by subclasses.
  """

  def __init__(self, name: Text,
               schema: Optional[schema_pb2.Schema] = None) -> None:
    """Constructs a statistics generator.

    Args:
      name: A unique name associated with the statistics generator.
      schema: An optional schema for the dataset.
    """
    self._schema = schema
    self._name = name

  @property
  def name(self):
    # The unique name this generator was constructed with.
    return self._name

  @property
  def schema(self):
    # The schema this generator was constructed with (may be None).
    return self._schema
# Have a type variable to represent the type of the accumulator
# in a combiner stats generator. Each CombinerStatsGenerator subclass binds
# this to its own partial-state type.
ACCTYPE = TypeVar('ACCTYPE')
class CombinerStatsGenerator(Generic[ACCTYPE], StatsGenerator):
  """A StatsGenerator which computes statistics using a combiner function.

  Statistics are computed in three phases: a partial state is produced while
  processing one batch of examples at a time, the partial states are merged
  together, and the final statistics are derived from the merged state at the
  end.

  This object mirrors a beam.CombineFn except for the add_input interface,
  which is expected to be defined by its sub-classes. Specifically, the
  generator must implement the following four methods:

  Initializes an accumulator to store the partial state and returns it.
    create_accumulator()

  Incorporates a batch of input examples (represented as an arrow RecordBatch)
  into the current accumulator and returns the updated accumulator.
    add_input(accumulator, input_record_batch)

  Merge the partial states in the accumulators and returns the accumulator
  containing the merged state.
    merge_accumulators(accumulators)

  Compute statistics from the partial state in the accumulator and
  return the result as a DatasetFeatureStatistics proto.
    extract_output(accumulator)
  """

  # TODO(b/176939874): Investigate which stats generators will benefit from
  # setup.
  def setup(self) -> None:
    """Performs per-worker initialization before combining starts.

    Costly initialization belongs here rather than in __init__ so that
    1) Beam properly accounts for it as per-worker setup cost, and
    2) it is not paid at pipeline construction time. The base implementation
    does nothing.
    """
    pass

  def create_accumulator(self) -> ACCTYPE:
    """Creates and returns an empty accumulator.

    Returns:
      An empty accumulator.
    """
    raise NotImplementedError

  def add_input(self, accumulator: ACCTYPE,
                input_record_batch: pa.RecordBatch) -> ACCTYPE:
    """Folds one batch of inputs into the accumulator.

    Args:
      accumulator: The current accumulator.
      input_record_batch: An Arrow RecordBatch whose columns are features and
        rows are examples. Columns are of type List<primitive>, or Null when
        a feature's value is None across all the examples in the batch.

    Returns:
      The accumulator updated with the statistics of this batch of inputs.
    """
    raise NotImplementedError

  def merge_accumulators(self, accumulators: Iterable[ACCTYPE]) -> ACCTYPE:
    """Combines several accumulators into a single accumulator value.

    Note: mutating any element in `accumulators` is not allowed and will
    result in undefined behavior.

    Args:
      accumulators: The accumulators to merge.

    Returns:
      The merged accumulator.
    """
    raise NotImplementedError

  # TODO(b/176939874): Investigate which stats generators will benefit from
  # compact.
  def compact(self, accumulator: ACCTYPE) -> ACCTYPE:
    """Optionally shrinks an accumulator before it is sent across the wire.

    The base implementation is the identity; derived classes may override it.

    Args:
      accumulator: The accumulator to compact.

    Returns:
      The compacted accumulator (the input itself by default).
    """
    return accumulator

  def extract_output(
      self, accumulator: ACCTYPE) -> statistics_pb2.DatasetFeatureStatistics:
    """Converts the final accumulator into this generator's output value.

    Args:
      accumulator: The final accumulator value.

    Returns:
      A proto representing the result of this stats generator.
    """
    raise NotImplementedError
# TODO(b/176939874): Add teardown() to all StatsGenerators if/when it is
# needed.
class CombinerFeatureStatsGenerator(Generic[ACCTYPE], StatsGenerator):
  """Generate feature level statistics using combiner function.

  A simplification of the CombinerStatsGenerator interface for the special
  case of statistics that need no cross-feature computation: it mirrors a
  beam.CombineFn applied to the values of one specific feature.
  """

  def setup(self) -> None:
    """Performs per-worker initialization before combining starts.

    Costly initialization belongs here rather than in __init__ so that
    1) Beam properly accounts for it as per-worker setup cost, and
    2) it is not paid at pipeline construction time. The base implementation
    does nothing.
    """
    pass

  def create_accumulator(self) -> ACCTYPE:
    """Creates and returns an empty accumulator.

    Returns:
      An empty accumulator.
    """
    raise NotImplementedError

  def add_input(self, accumulator: ACCTYPE, feature_path: types.FeaturePath,
                feature_array: pa.Array) -> ACCTYPE:
    """Folds one batch of feature values into the accumulator.

    Args:
      accumulator: The current accumulator.
      feature_path: The path of the feature.
      feature_array: An arrow Array representing a batch of feature values
        which should be added to the accumulator.

    Returns:
      The accumulator updated with the statistics of this batch of inputs.
    """
    raise NotImplementedError

  def merge_accumulators(self, accumulators: Iterable[ACCTYPE]) -> ACCTYPE:
    """Combines several accumulators into a single accumulator value.

    Args:
      accumulators: The accumulators to merge.

    Returns:
      The merged accumulator.
    """
    raise NotImplementedError

  def compact(self, accumulator: ACCTYPE) -> ACCTYPE:
    """Optionally shrinks an accumulator before it is sent across the wire.

    The base implementation is the identity; derived classes may override it.

    Args:
      accumulator: The accumulator to compact.

    Returns:
      The compacted accumulator (the input itself by default).
    """
    return accumulator

  def extract_output(
      self, accumulator: ACCTYPE) -> statistics_pb2.FeatureNameStatistics:
    """Converts the final accumulator into this generator's output value.

    Args:
      accumulator: The final accumulator value.

    Returns:
      A proto representing the result of this stats generator.
    """
    raise NotImplementedError
# Type variable for the accumulator type of a ConstituentStatsGenerator.
CONSTITUENT_ACCTYPE = TypeVar('CONSTITUENT_ACCTYPE')
class ConstituentStatsGenerator(
    Generic[CONSTITUENT_ACCTYPE], metaclass=abc.ABCMeta):
  """A stats generator meant to be used as a part of a composite generator.

  Constituent generators exist so that several stats generators can share
  logic. Each one is functionally identical to a beam.CombineFn, except that
  add_input is called with instances of InputBatch.
  """

  def setup(self) -> None:
    """Prepares this constituent generator.

    Costly initialization belongs here rather than in __init__ so that
    1) Beam properly accounts for it as per-worker setup cost, and
    2) it is not paid at pipeline construction time. The base implementation
    does nothing.
    """
    pass

  @classmethod
  @abc.abstractmethod
  def key(cls) -> Hashable:
    """A class method which returns an ID for instances of this stats generator.

    This method should accept the same arguments as __init__, so that
    ConstituentStatsGenerator.key(*init_args) is identical to
    ConstituentStatsGenerator(*init_args).get_key(). A
    CompositeStatsGenerator can therefore construct a specific constituent
    generator in its __init__, and then recover the corresponding output
    value in its extract_composite_output method.

    Returns:
      A unique ID for instances of this stats generator class.
    """

  @abc.abstractmethod
  def get_key(self) -> Hashable:
    """Returns the ID of this specific generator.

    Returns:
      A unique ID for this stats generator class instance.
    """

  @abc.abstractmethod
  def create_accumulator(self) -> CONSTITUENT_ACCTYPE:
    """Creates and returns an empty accumulator.

    Returns:
      An empty accumulator.
    """

  @abc.abstractmethod
  def add_input(self, accumulator: CONSTITUENT_ACCTYPE,
                batch: input_batch.InputBatch) -> CONSTITUENT_ACCTYPE:
    """Folds one batch of inputs into the accumulator.

    Args:
      accumulator: The current accumulator.
      batch: An InputBatch which wraps an Arrow RecordBatch whose columns are
        features and rows are examples. Columns are of type List<primitive>,
        or Null when a feature's value is None across all the examples in
        the batch.

    Returns:
      The accumulator updated with the statistics of this batch of inputs.
    """

  @abc.abstractmethod
  def merge_accumulators(
      self, accumulators: Iterable[CONSTITUENT_ACCTYPE]) -> CONSTITUENT_ACCTYPE:
    """Combines several accumulators into a single accumulator value.

    Args:
      accumulators: The accumulators to merge.

    Returns:
      The merged accumulator.
    """

  def compact(self, accumulator: CONSTITUENT_ACCTYPE) -> CONSTITUENT_ACCTYPE:
    """Optionally shrinks an accumulator before it is sent across the wire.

    The base implementation is the identity; derived classes may override it.

    Args:
      accumulator: The accumulator to compact.

    Returns:
      The compacted accumulator.
    """
    return accumulator

  @abc.abstractmethod
  def extract_output(self, accumulator: CONSTITUENT_ACCTYPE) -> Any:
    """Converts the final accumulator into this constituent's output value.

    Args:
      accumulator: The final accumulator value.

    Returns:
      The final output value which should be used by composite generators
      which use this constituent generator.
    """
class CompositeStatsGenerator(CombinerStatsGenerator,
                              Generic[CONSTITUENT_ACCTYPE]):
  """A combiner generator built from ConstituentStatsGenerators.

  Typical usage involves overriding the __init__, to provide a set of
  constituent generators, and extract_composite_output, to process the outputs
  of those constituent generators. As a toy example, consider:

  class ExampleCompositeStatsGenerator(
      stats_generator.CompositeStatsGenerator):

    def __init__(self,
                 schema: schema_pb2.Schema,
                 name: Text = 'ExampleCompositeStatsGenerator'
                ) -> None:
      # custom logic to build the set of relevant constituents
      self._paths = [types.FeaturePath(['f1']), types.FeaturePath(['f2'])]
      constituents = [CountMissingCombiner(p) for p in self._paths]
      # call super class init with constituents
      super(ExampleCompositeStatsGenerator, self).__init__(
          name, constituents, schema)

    def extract_composite_output(self, accumulator):
      # custom logic to convert constituent outputs to stats proto
      stats = statistics_pb2.DatasetFeatureStatistics()
      for path in self._paths:
        # lookup output from a particular combiner using the key() function,
        # which typically takes the same args as __init__.
        num_missing = accumulator[CountMissingCombiner.key(path)]
        stats.features.add(path=path).custom_stats.add(
            name='num_missing', num=num_missing)
      return stats

  This class is very similar to the SingleInputTupleCombineFn and adds two
  small features:
  1) The input value passed to add_inputs is wrapped in an InputBatch object
     before being passed on to the constituent generators.
  2) The API for providing constituents and retrieving their outputs is a dict
     rather than a tuple, which makes it easier to keep track of which output
     came from which constituent generator.
  """

  def __init__(self, name: Text,
               constituents: Iterable[ConstituentStatsGenerator],
               schema: Optional[schema_pb2.Schema]) -> None:
    """Initializes the composite generator.

    Args:
      name: A unique name associated with the statistics generator.
      constituents: The constituent generators whose outputs this composite
        generator combines.
      schema: An optional schema for the dataset.
    """
    super(CompositeStatsGenerator, self).__init__(name, schema)
    # Materialize `constituents` once so that a generator argument is not
    # exhausted, and so that an empty iterable produces empty tuples instead
    # of the ValueError that unpacking zip(*...) would raise.
    self._constituents = tuple(constituents)
    self._keys = tuple(c.get_key() for c in self._constituents)

  def setup(self):
    # Forward per-worker setup to every constituent generator.
    for c in self._constituents:
      c.setup()

  def create_accumulator(self) -> List[CONSTITUENT_ACCTYPE]:
    # One accumulator per constituent, in the same order as
    # self._constituents.
    return [c.create_accumulator() for c in self._constituents]

  def add_input(
      self, accumulator: List[CONSTITUENT_ACCTYPE],
      input_record_batch: pa.RecordBatch) -> List[CONSTITUENT_ACCTYPE]:
    # Wrap the RecordBatch once and share the InputBatch across all
    # constituents.
    batch = input_batch.InputBatch(input_record_batch)
    return [
        c.add_input(a, batch) for c, a in zip(self._constituents, accumulator)
    ]

  def merge_accumulators(
      self, accumulators: Iterable[List[CONSTITUENT_ACCTYPE]]
  ) -> List[CONSTITUENT_ACCTYPE]:
    # Transpose the per-call accumulator lists so that each constituent
    # merges only its own accumulators.
    return [
        c.merge_accumulators(a)
        for c, a in zip(self._constituents, zip(*accumulators))
    ]

  def compact(
      self,
      accumulator: List[CONSTITUENT_ACCTYPE]) -> List[CONSTITUENT_ACCTYPE]:
    return [c.compact(a) for c, a in zip(self._constituents, accumulator)]

  def extract_output(
      self, accumulator: List[CONSTITUENT_ACCTYPE]
  ) -> statistics_pb2.DatasetFeatureStatistics:
    # Map each constituent's key to its extracted output and delegate to the
    # subclass-provided hook.
    return self.extract_composite_output(
        dict(
            zip(self._keys,
                (c.extract_output(a)
                 for c, a in zip(self._constituents, accumulator)))))

  def extract_composite_output(
      self, accumulator: Dict[Text,
                              Any]) -> statistics_pb2.DatasetFeatureStatistics:
    """Extracts output from a dict of outputs for each constituent combiner.

    Args:
      accumulator: A dict mapping from combiner keys to the corresponding
        output for that combiner.

    Returns:
      A proto representing the result of this stats generator.
    """
    raise NotImplementedError()
class TransformStatsGenerator(StatsGenerator):
  """A StatsGenerator backed by an arbitrary Beam PTransform.

  The wrapped PTransform consumes a PCollection of tuples, each containing a
  slice key and an Arrow RecordBatch representing a batch of examples. It
  must produce a PCollection of tuples, each containing a slice key and a
  DatasetFeatureStatistics proto representing the statistics of that slice.
  """

  def __init__(self,
               name: Text,
               ptransform: beam.PTransform,
               schema: Optional[schema_pb2.Schema] = None) -> None:
    """Constructs the generator around the PTransform that computes its stats.

    Args:
      name: A unique name associated with the statistics generator.
      ptransform: The Beam PTransform used to compute the statistics.
      schema: An optional schema for the dataset.
    """
    self._ptransform = ptransform
    super(TransformStatsGenerator, self).__init__(name, schema)

  @property
  def ptransform(self):
    # The wrapped Beam PTransform.
    return self._ptransform