analyzers.py
# Copyright 2021 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental functions that involve a full pass over the dataset.
This module contains functions that are used in the preprocessing function, to
define a full pass operation such as computing the sum, min, max or unique
values of a tensor over the entire dataset. This is implemented by a reduction
operation in the Beam implementation.
From the user's point of view, an analyzer appears as a regular TensorFlow
function, i.e. it accepts and returns tensors. However, it is represented in
the graph as an `Analyzer`, which is not a TensorFlow op but a placeholder for
the computation that takes place outside of TensorFlow.
"""
from typing import Any, Collection, List, Optional, Tuple, Union, Iterable
import numpy as np
import pyarrow as pa
import tensorflow as tf
from tensorflow_transform import analyzer_nodes
from tensorflow_transform import analyzers
from tensorflow_transform import common
from tensorflow_transform import common_types
from tensorflow_transform import nodes
from tensorflow_transform import tf_utils
from tfx_bsl import sketches
# TODO(b/243513856): Switch to `collections.namedtuple` or `typing.NamedTuple`
# once the Spark issue is resolved.
from tfx_bsl.types import tfx_namedtuple
from typing_extensions import Protocol
__all__ = [
'PTransformAnalyzerCacheCoder',
'SimpleJsonPTransformAnalyzerCacheCoder',
'CacheablePTransformAnalyzer',
'ptransform_analyzer',
'approximate_vocabulary',
]
PTransformAnalyzerCacheCoder = analyzer_nodes.CacheCoder
SimpleJsonPTransformAnalyzerCacheCoder = analyzer_nodes.JsonNumpyCacheCoder
_APPROXIMATE_VOCAB_FILENAME_PREFIX = 'approx_vocab_'
_APPROXIMATE_VOCAB_FREQUENCY_FILENAME_PREFIX = 'approx_vocab_frequency_'
class _BeamPTransform(Protocol):
"""Pytype for `beam.PTransform` without depending on beam in this module.
"""
def expand(self, pcol: Any) -> Any:
...
def default_label(self) -> str:
...
# TODO(zoyahav): Add an example for using this API.
class CacheablePTransformAnalyzer(
tfx_namedtuple.TypedNamedTuple(
'PTransformCachedAnalyzer',
[('make_accumulators_ptransform', _BeamPTransform),
('merge_accumulators_ptransform', _BeamPTransform),
('extract_output_ptransform', _BeamPTransform),
('cache_coder', PTransformAnalyzerCacheCoder)])):
"""A PTransformAnalyzer which enables analyzer cache.
WARNING: This should only be used if the analyzer can correctly be separated
into make_accumulators, merge_accumulators and extract_output stages.
1. make_accumulators_ptransform: a `beam.PTransform` which maps data to a
more compact mergeable representation (accumulator). Mergeable here means
that it is possible to combine multiple representations produced from a
partition of the dataset into a representation of the entire dataset.
2. merge_accumulators_ptransform: a `beam.PTransform` which operates on a
collection of accumulators, i.e. the results of both the
make_accumulators_ptransform and merge_accumulators_ptransform stages, and
produces a single reduced accumulator. This operation must be associative
and commutative in order to have reliably reproducible results.
3. extract_output_ptransform: a `beam.PTransform` which operates on the result
of the merge_accumulators_ptransform stage, and produces the outputs of the
analyzer. These outputs must be consistent with the `output_dtypes` and
`output_shapes` provided to `ptransform_analyzer`.
This container also holds a `cache_coder` (`PTransformAnalyzerCacheCoder`)
which can encode outputs and decode the inputs of the
`merge_accumulators_ptransform` stage.
In many cases, `SimpleJsonPTransformAnalyzerCacheCoder` would be sufficient.
To ensure the correctness of this analyzer, the following must hold:
merge(make({D1, ..., Dn})) == merge({make(D1), ..., make(Dn)})
"""
__slots__ = ()
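# The following is an illustrative sketch (not part of the library): a
# hypothetical `CacheablePTransformAnalyzer` that computes the global sum of
# its single input, assuming `apache_beam` is imported as `beam`. The three
# stages and the `cache_coder` correspond to the fields documented above.
#
#   class _MakeSumAccumulators(beam.PTransform):
#     def expand(self, pcoll):
#       # Each element is a tuple of ndarrays, one per analyzer input; reduce
#       # each batch, then each dataset partition, to a partial sum.
#       return (pcoll
#               | 'BatchSum' >> beam.Map(lambda batch: float(np.sum(batch[0])))
#               | 'PartialSum' >> beam.CombineGlobally(sum))
#
#   class _MergeSumAccumulators(beam.PTransform):
#     def expand(self, pcoll):
#       # Merge partial sums (accumulators); summation is associative and
#       # commutative, as required for reproducible results.
#       return pcoll | 'MergeSums' >> beam.CombineGlobally(sum)
#
#   class _ExtractSumOutput(beam.PTransform):
#     def expand(self, pcoll):
#       # Shape the merged accumulator to match the `output_dtypes` and
#       # `output_shapes` passed to `ptransform_analyzer`.
#       return pcoll | 'ToOutput' >> beam.Map(
#           lambda total: np.array(total, np.float32))
#
#   sum_analyzer = CacheablePTransformAnalyzer(
#       make_accumulators_ptransform=_MakeSumAccumulators(),
#       merge_accumulators_ptransform=_MergeSumAccumulators(),
#       extract_output_ptransform=_ExtractSumOutput(),
#       cache_coder=SimpleJsonPTransformAnalyzerCacheCoder())
#
# `sum_analyzer` would then be passed as the `ptransform` argument of
# `tft.experimental.ptransform_analyzer` with `output_dtypes=[tf.float32]`
# and `output_shapes=[[]]`.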
def _apply_analyzer(ptransform: Union[_BeamPTransform,
CacheablePTransformAnalyzer],
*tensor_inputs: common_types.TensorType,
**analyzer_def_kwargs: Any) -> Tuple[tf.Tensor, ...]:
"""Applies the analyzer over the whole dataset.
Args:
ptransform: The `beam.PTransform` or `CacheablePTransformAnalyzer` to apply
over the dataset.
*tensor_inputs: A list of input `Tensor`s, `SparseTensor`s, or
`RaggedTensor`s.
**analyzer_def_kwargs: KW arguments to use when constructing
analyzer_def_cls.
Returns:
A tuple of `Tensor`s representing the values of the analysis result.
"""
input_values_node = analyzer_nodes.get_input_tensors_value_nodes(
tensor_inputs)
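# For a `CacheablePTransformAnalyzer`, the computation is split into a
# partitionable make/merge stage, whose per-span results can be cached,
# followed by a non-partitionable global merge and output extraction;
# otherwise the single PTransform is applied once over the full dataset.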
if isinstance(ptransform, CacheablePTransformAnalyzer):
with tf.compat.v1.name_scope('make_accumulators'):
make_accumulators_value_node = nodes.apply_multi_output_operation(
analyzer_nodes.PTransform,
input_values_node,
ptransform=ptransform.make_accumulators_ptransform,
is_partitionable=True,
**analyzer_def_kwargs)
with tf.compat.v1.name_scope('local_merge_accumulators'):
cached_value_nodes = nodes.apply_multi_output_operation(
analyzer_nodes.PTransform,
*make_accumulators_value_node,
ptransform=ptransform.merge_accumulators_ptransform,
is_partitionable=True,
cache_coder=ptransform.cache_coder,
**analyzer_def_kwargs)
with tf.compat.v1.name_scope('global_merge_accumulators'):
merge_output_value_nodes = nodes.apply_multi_output_operation(
analyzer_nodes.PTransform,
*cached_value_nodes,
ptransform=ptransform.merge_accumulators_ptransform,
is_partitionable=False,
**analyzer_def_kwargs)
with tf.compat.v1.name_scope('extract_output'):
output_value_nodes = nodes.apply_multi_output_operation(
analyzer_nodes.PTransform,
*merge_output_value_nodes,
ptransform=ptransform.extract_output_ptransform,
is_partitionable=False,
**analyzer_def_kwargs)
else:
output_value_nodes = nodes.apply_multi_output_operation(
analyzer_nodes.PTransform,
input_values_node,
ptransform=ptransform,
is_partitionable=False,
**analyzer_def_kwargs)
return tuple(map(analyzer_nodes.wrap_as_tensor, output_value_nodes))
# TODO(b/164921571): Support output assets in tfrecord format.
@common.log_api_use(common.ANALYZER_COLLECTION)
def ptransform_analyzer(
inputs: Collection[tf.Tensor],
ptransform: Union[_BeamPTransform, CacheablePTransformAnalyzer],
output_dtypes: Collection[tf.dtypes.DType],
output_shapes: Collection[List[int]],
output_asset_default_values: Optional[Collection[Optional[bytes]]] = None,
name: Optional[str] = None):
# pylint: disable=line-too-long
"""Applies a user-provided PTransform over the whole dataset.
WARNING: This is experimental.
Note that in order to have asset files copied correctly, any outputs that
represent asset filenames must be added to the `tf.GraphKeys.ASSET_FILEPATHS`
collection by the caller if using Transform's APIs in compat v1 mode.
Example:
>>> class MeanPerKey(beam.PTransform):
... def expand(self, pcoll: beam.PCollection[Tuple[np.ndarray, np.ndarray]]) -> Tuple[beam.PCollection[np.ndarray], beam.PCollection[np.ndarray]]:
... def extract_output(key_value_pairs):
... keys, values = zip(*key_value_pairs)
... return [beam.TaggedOutput('keys', keys),
... beam.TaggedOutput('values', values)]
... return tuple(
... pcoll
... | 'ZipAndFlatten' >> beam.FlatMap(lambda batches: list(zip(*batches)))
... | 'MeanPerKey' >> beam.CombinePerKey(beam.combiners.MeanCombineFn())
... | 'ToList' >> beam.combiners.ToList()
... | 'Extract' >> beam.FlatMap(extract_output).with_outputs(
... 'keys', 'values'))
>>> def preprocessing_fn(inputs):
... outputs = tft.experimental.ptransform_analyzer(
... inputs=[inputs['s'], inputs['x']],
... ptransform=MeanPerKey(),
... output_dtypes=[tf.string, tf.float32],
... output_shapes=[[2], [2]])
... (keys, means) = outputs
... mean_a = tf.reshape(tf.gather(means, tf.where(keys == 'a')), [])
... return { 'x/mean_a': inputs['x'] / mean_a }
>>> raw_data = [dict(x=1, s='a'), dict(x=8, s='b'), dict(x=3, s='a')]
>>> feature_spec = dict(
... x=tf.io.FixedLenFeature([], tf.float32),
... s=tf.io.FixedLenFeature([], tf.string))
>>> raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec)
>>> with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
... transformed_dataset, transform_fn = (
... (raw_data, raw_data_metadata)
... | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
>>> transformed_data, transformed_metadata = transformed_dataset
>>> transformed_data
[{'x/mean_a': 0.5}, {'x/mean_a': 4.0}, {'x/mean_a': 1.5}]
Args:
inputs: An ordered collection of input `Tensor`s.
ptransform: A Beam PTransform that accepts a Beam PCollection where each
element is a tuple of `ndarray`s. Each element in the tuple contains a
batch of values for the corresponding input tensor of the analyzer,
maintaining its shape and dtype.
It returns a `PCollection`, or a tuple of `PCollections`, each containing
a single element which is an `ndarray` or a list of primitive types. The
contents of these output `PCollection`s must be consistent with the given
values of `output_dtypes` and `output_shapes`.
It may inherit from `tft_beam.experimental.PTransformAnalyzer` if access
to a temp base directory is needed.
Alternatively, it could be an instance of
`tft.experimental.CacheablePTransformAnalyzer` in order to enable cache
for this analyzer, when analyzer cache is enabled for this pipeline.
output_dtypes: An ordered collection of TensorFlow dtypes of the output of
the analyzer.
output_shapes: An ordered collection of shapes of the output of the
analyzer. Must have the same length as output_dtypes.
output_asset_default_values: (Optional) An ordered collection of optional
`bytes` aligned with output_dtypes/output_shapes. Every item in this
collection which is not `None` indicates that the output is a TF asset
path, and its value would be used as the default value of this asset file
prior to analysis.
name: (Optional) Similar to a TF op name. Used to define a unique scope for
this analyzer, which can be used for debugging info.
Returns:
A list of output `Tensor`s. These will have `dtype` and `shape` as
specified by `output_dtypes` and `output_shapes`.
Raises:
ValueError: If output_dtypes and output_shapes have different lengths.
"""
# pylint: enable=line-too-long
if len(output_dtypes) != len(output_shapes):
raise ValueError('output_dtypes ({}) and output_shapes ({}) had different'
' lengths'.format(output_dtypes, output_shapes))
if output_asset_default_values is not None:
if len(output_asset_default_values) != len(output_dtypes):
raise ValueError(
'output_dtypes ({}) and output_asset_default_values ({}) had '
'different lengths'.format(output_dtypes,
output_asset_default_values))
output_asset_default_values = [
analyzer_nodes.TemporaryAssetInfo(value, 'text')
for value in output_asset_default_values
]
else:
output_asset_default_values = [None] * len(output_dtypes)
with tf.compat.v1.name_scope(name, 'ptransform'):
output_tensor_infos = [
analyzer_nodes.TensorInfo(dtype, shape, default_asset_content)
for dtype, shape, default_asset_content in zip(
output_dtypes, output_shapes, output_asset_default_values)
]
return _apply_analyzer(
ptransform, *inputs, output_tensor_info_list=output_tensor_infos)
def _get_approx_vocab_filename(vocab_filename: Optional[str],
store_frequency: bool) -> str:
"""Returns a sanitized vocabulary filename with appropriate prefix applied.
Args:
vocab_filename: The file name for the approximate vocabulary file. If None,
the "approximate_vocabulary" scope name in the context of this graph will
be used as the file name.
store_frequency: A bool that is true when the vocabulary for which this
generates a filename stores term frequency. False otherwise.
Returns:
A valid filename.
"""
if vocab_filename is not None:
prefix = None
elif store_frequency:
prefix = _APPROXIMATE_VOCAB_FREQUENCY_FILENAME_PREFIX
else:
prefix = _APPROXIMATE_VOCAB_FILENAME_PREFIX
# Make the file name path safe.
return analyzers.sanitized_vocab_filename(vocab_filename, prefix=prefix)
@common.log_api_use(common.ANALYZER_COLLECTION)
def approximate_vocabulary(
x: common_types.TensorType,
top_k: int,
vocab_filename: Optional[str] = None,
store_frequency: bool = False,
weights: Optional[tf.Tensor] = None,
file_format: common_types.VocabularyFileFormatType = analyzers
.DEFAULT_VOCABULARY_FILE_FORMAT,
name: Optional[str] = None) -> common_types.TemporaryAnalyzerOutputType:
r"""Computes the unique values of a `Tensor` over the whole dataset.
Approximately computes the unique values taken by `x`, which can be a
`Tensor`, `SparseTensor`, or `RaggedTensor` of any size. The unique values
will be aggregated over all dimensions of `x` and all instances.
This analyzer provides an approximate alternative to `tft.vocabulary` that can
be more efficient with a smaller `top_k` and/or a smaller number of unique
elements in `x`. As a rule of thumb, `approximate_vocabulary` becomes more
efficient than `tft.vocabulary` if `top_k` or the number of unique elements in
`x` is smaller than 2*10^5. Moreover, this analyzer is subject to a combiner
packing optimization that does not apply to `tft.vocabulary`. Caching is also
more efficient with the approximate implementation since the filtration
happens before the cache is written out. The output artifact of
`approximate_vocabulary` is consistent with that of `tft.vocabulary` and can
be used with the `tft.apply_vocabulary` mapper.
Implementation of this analyzer is based on the Misra-Gries algorithm [1]. It
stores at most `top_k` elements with lower bound frequency estimates at a
time. The algorithm keeps track of the approximation error `delta` such that
for any item x with true frequency X:
frequency[x] <= X <= frequency[x] + delta,
delta <= (m - m') / (top_k + 1),
where m is the total frequency of the items in the dataset and m' is the sum
of the lower bound estimates in `frequency` [2]. For datasets that are Zipfian
distributed with parameter `a`, the algorithm provides an expected value of
delta = m / (top_k ^ a) [3].
[1]
https://www.cs.utexas.edu/users/misra/scannedPdf.dir/FindRepeatedElements.pdf
[2] http://www.cohenwang.com/edith/bigdataclass2013/lectures/lecture1.pdf
[3] http://dimacs.rutgers.edu/~graham/pubs/papers/countersj.pdf
If `file_format` is 'text' and a token contains the '\n' or '\r' characters
or is empty, it will be discarded.
If an integer `Tensor` is provided, its semantic type should be categorical,
not continuous/numeric, since computing a vocabulary over a continuous
feature is not appropriate.
The unique values are sorted by decreasing frequency and then reverse
lexicographical order (e.g. [('a', 5), ('c', 3), ('b', 3)]). This is true even
if `x` is of a numerical dtype (e.g. [('3', 5), ('2', 3), ('111', 3)]).
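Example (an illustrative sketch; the feature name 's' is hypothetical):
>>> def preprocessing_fn(inputs):
...   deferred_vocab_path = tft.experimental.approximate_vocabulary(
...       inputs['s'], top_k=10)
...   return {
...       's_integerized': tft.apply_vocabulary(inputs['s'],
...                                             deferred_vocab_path)
...   }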
Args:
x: A categorical/discrete input `Tensor`, `SparseTensor`, or `RaggedTensor`
with dtype tf.string or tf.int[8|16|32|64].
top_k: Limit the generated vocabulary to the first `top_k` elements. Note
that if `top_k` is larger than the number of unique elements in `x`, then
the result will be exact.
vocab_filename: The file name for the vocabulary file. If None, a file name
will be chosen based on the current scope. If not None, should be unique
within a given preprocessing function. NOTE: To make your pipelines
resilient to implementation details, please set `vocab_filename` if you are
using the vocabulary file in a downstream component.
store_frequency: If True, frequency of the words is stored in the vocabulary
file. Each line in the file will be of the form 'frequency word'. NOTE: if
this is True then the computed vocabulary cannot be used with
`tft.apply_vocabulary` directly, since frequencies are added to the
beginning of each row of the vocabulary, which the mapper will not
ignore.
weights: (Optional) Weights `Tensor` for the vocabulary. It must have the
same shape as x.
file_format: (Optional) A str. The format of the resulting vocabulary file.
Accepted formats are: 'tfrecord_gzip', 'text'. 'tfrecord_gzip' requires
tensorflow>=2.4. The default value is 'text'.
name: (Optional) A name for this operation.
Returns:
The path name for the vocabulary file containing the unique values of `x`.
Raises:
ValueError: If `top_k` is not positive.
If `file_format` is not in the list of allowed formats.
If x.dtype is not string or integral.
"""
if top_k <= 0:
raise ValueError('top_k must be positive, but got: %r' % top_k)
elif top_k > analyzers.LARGE_VOCAB_TOP_K:
raise ValueError('Provided top_k threshold is too large for the '
'approximate calculation: if the expected number of '
'unique elements is larger than top_k, tft.vocabulary may '
'be more efficient. Maximum allowed top_k is {}'.format(
analyzers.LARGE_VOCAB_TOP_K))
if file_format not in analyzers.ALLOWED_VOCABULARY_FILE_FORMATS:
raise ValueError(
'"{}" is not an accepted file_format. It should be one of: {}'.format(
file_format, analyzers.ALLOWED_VOCABULARY_FILE_FORMATS))
if x.dtype != tf.string and not x.dtype.is_integer:
raise ValueError('expected tf.string or integer but got %r' % x.dtype)
with tf.compat.v1.name_scope(name, 'approximate_vocabulary'):
vocabulary_key = vocab_filename
vocab_filename = _get_approx_vocab_filename(vocab_filename, store_frequency)
analyzer_inputs = _get_approximate_vocabulary_analyzer_inputs(
x=x, file_format=file_format, weights=weights)
return _approximate_vocabulary_analyzer_nodes(
analyzer_inputs=analyzer_inputs,
input_dtype=x.dtype.name,
vocab_filename=vocab_filename,
top_k=top_k,
store_frequency=store_frequency,
file_format=file_format,
vocabulary_key=vocabulary_key)
def _approximate_vocabulary_analyzer_nodes(
analyzer_inputs: Collection[tf.Tensor], input_dtype: tf.dtypes.DType,
vocab_filename: str, top_k: int, store_frequency: bool,
file_format: common_types.VocabularyFileFormatType,
vocabulary_key: str) -> common_types.TemporaryAnalyzerOutputType:
"""Internal helper for analyzing vocab. See `vocabulary` doc string."""
# TODO(b/208879020): Add vocabulary size annotation for this analyzer.
analyzers.register_vocab(
vocab_filename, vocabulary_key=vocabulary_key, file_format=file_format)
outputs_value_nodes = analyzers.apply_cacheable_combine_operation(
_VocabularyCombiner(top_k, input_dtype), *analyzer_inputs)
flattened_outputs_value_node = nodes.apply_operation(
analyzer_nodes.FlattenLists, *outputs_value_nodes)
vocab_filename_node = nodes.apply_operation(
analyzer_nodes.VocabularyOrderAndWrite,
flattened_outputs_value_node,
vocab_filename=vocab_filename,
store_frequency=store_frequency,
input_dtype=input_dtype,
file_format=file_format,
fingerprint_shuffle=False,
input_is_sorted=True)
return analyzer_nodes.wrap_as_tensor(vocab_filename_node)
class _MisraGriesSketchCoder(analyzer_nodes.CacheCoder):
"""Cache coder for the approximate vocabulary accumulator."""
def encode_cache(self, accumulator: sketches.MisraGriesSketch) -> bytes:
return accumulator.Serialize()
def decode_cache(self,
encoded_accumulator: bytes) -> sketches.MisraGriesSketch:
return sketches.MisraGriesSketch.Deserialize(encoded_accumulator)
class _VocabularyCombiner(analyzer_nodes.Combiner):
"""Approximately computes unique values on the PCollection."""
def __init__(self, top_k: int, input_dtype: tf.dtypes.DType):
self._top_k = top_k
self._input_dtype = input_dtype
def create_accumulator(self) -> sketches.MisraGriesSketch:
return sketches.MisraGriesSketch(self._top_k)
def add_input(
self, accumulator: sketches.MisraGriesSketch,
next_input: Tuple[np.ndarray, np.ndarray]) -> sketches.MisraGriesSketch:
items, weights = next_input
if items.size:
accumulator.AddValues(pa.array(items), pa.array(weights, pa.float32()))
return accumulator
def merge_accumulators(
self, accumulators: Iterable[sketches.MisraGriesSketch]
) -> sketches.MisraGriesSketch:
# Make sure that `accumulators` is an iterator (so that the position is
# remembered).
accumulators = iter(accumulators)
result = next(accumulators)
for accumulator in accumulators:
result.Merge(accumulator)
return result
def extract_output(self,
accumulator: sketches.MisraGriesSketch) -> np.ndarray:
estimate = accumulator.Estimate()
estimate.validate()
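# Stack the sketch estimate's parallel columns of items and counts into rows
# ordered to match the (frequency, term) layout that `VocabularyOrderAndWrite`
# expects downstream.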
result = np.dstack(reversed(estimate.flatten()))
if not result.size:
return np.array(
[[analyzers.get_empy_vocabulary_dummy_value(self._input_dtype)]],
dtype=object)
else:
return result
def output_tensor_infos(self) -> List[analyzer_nodes.TensorInfo]:
return [analyzer_nodes.TensorInfo(tf.string, [None, 2], None)]
@property
def accumulator_coder(self) -> _MisraGriesSketchCoder:
return _MisraGriesSketchCoder()
def _get_approximate_vocabulary_analyzer_inputs(
x: common_types.TensorType,
file_format: common_types.VocabularyFileFormatType,
weights: Optional[common_types.TensorType] = None,
) -> Tuple[common_types.TensorType, common_types.TensorType]:
"""Helper for constructing approximate vocabulary inputs from tensors.
Args:
x: `Tensor`, `SparseTensor`, or `RaggedTensor` to compute vocabulary over.
file_format: The format of the resulting vocabulary file.
'tfrecord_gzip' requires tensorflow>=2.4.
weights: Optional `Tensor` of weights.
Returns:
A tuple of batch-reduced `Tensor`s (unique values and their counts or summed
weights) to feed to vocabulary analysis.
"""
filter_regex = analyzers.get_vocab_newline_characters_regex(
x.dtype, file_format)
reduced_batch = tf_utils.reduce_batch_weighted_counts(
x, weights=weights, force=True, filter_regex=filter_regex)
assert reduced_batch.summed_positive_per_x_and_y is None
if weights is None:
assert reduced_batch.summed_weights_per_x is None
return (reduced_batch.unique_x, reduced_batch.counts_per_x)
else:
return (reduced_batch.unique_x, reduced_batch.summed_weights_per_x)