# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data reader module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import multiprocessing
import random
import six
from six.moves import range # pylint: disable=redefined-builtin
from tensor2tensor.utils import mlperf_log
import tensorflow as tf


def cast_ints_to_int32(features):
  f = {}
  for k, v in sorted(six.iteritems(features)):
    if v.dtype in [tf.int64, tf.uint8]:
      v = tf.to_int32(v)
    f[k] = v
  return f


def example_length(example):
  length = 0
  # Length of the example is the maximum length of the feature lengths
  for _, v in sorted(six.iteritems(example)):
    # For images the sequence length is the size of the spatial dimensions.
    feature_length = tf.shape(v)[0]
    if len(v.get_shape()) > 2:
      feature_length = tf.shape(v)[0] * tf.shape(v)[1]
    length = tf.maximum(length, feature_length)
  return length
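
# Illustration (added comment, not in the original module): for a 1-D
# "targets" feature of length 15 and a 3-D image "inputs" of shape
# [32, 32, 3], example_length returns max(15, 32 * 32) = 1024, because
# features with more than two dimensions use the product of their first two
# (spatial) dimensions as their length.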


def example_valid_size(example, min_length, max_length):
  length = example_length(example)
  return tf.logical_and(
      length >= min_length,
      length <= max_length,
  )


def padded_batch(dataset, batch_size, padded_shapes=None):
  padded_shapes = padded_shapes or dict(
      [(name, [None] * len(shape))
       for name, shape in dataset.output_shapes.items()])
  return dataset.padded_batch(batch_size, padded_shapes)


def _bucket_boundaries(max_length, min_length=8, length_bucket_step=1.1):
  """A default set of length-bucket boundaries."""
  assert length_bucket_step > 1.0
  x = min_length
  boundaries = []
  while x < max_length:
    boundaries.append(x)
    x = max(x + 1, int(x * length_bucket_step))
  return boundaries
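
# Illustration (added comment, not in the original module): with the defaults
# min_length=8 and length_bucket_step=1.1, the boundaries start at 8 and grow
# by 1 per step while int(x * 1.1) is still less than x + 1; once x is large
# enough, each boundary is roughly 10% larger than the previous one.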


def _batching_scheme(batch_size,
                     max_length,
                     min_length_bucket,
                     length_bucket_step,
                     drop_long_sequences=False,
                     shard_multiplier=1,
                     length_multiplier=1,
                     min_length=0):
  """A batching scheme based on model hyperparameters.

  Every batch contains a number of sequences divisible by `shard_multiplier`.

  Args:
    batch_size: int, total number of tokens in a batch.
    max_length: int, sequences longer than this will be skipped. Defaults to
      batch_size.
    min_length_bucket: int
    length_bucket_step: float greater than 1.0
    drop_long_sequences: bool, if True, then sequences longer than
      `max_length` are dropped. This prevents generating batches with
      more than the usual number of tokens, which can cause out-of-memory
      errors.
    shard_multiplier: an integer increasing the batch_size to suit splitting
      across datashards.
    length_multiplier: an integer multiplier that is used to increase the
      batch sizes and sequence length tolerance.
    min_length: int, sequences shorter than this will be skipped.

  Returns:
    A dictionary with parameters that can be passed to input_pipeline:
      * boundaries: list of bucket boundaries
      * batch_sizes: list of batch sizes for each length bucket
      * max_length: int, maximum length of an example

  Raises:
    ValueError: If min_length > max_length
  """
  max_length = max_length or batch_size
  if max_length < min_length:
    raise ValueError("max_length must be greater or equal to min_length")
  boundaries = _bucket_boundaries(max_length, min_length_bucket,
                                  length_bucket_step)
  boundaries = [boundary * length_multiplier for boundary in boundaries]
  max_length *= length_multiplier

  batch_sizes = [
      max(1, batch_size // length) for length in boundaries + [max_length]
  ]
  max_batch_size = max(batch_sizes)
  # Since the Datasets API only allows a single constant for window_size,
  # and it needs to divide all bucket_batch_sizes, we pick a highly-composite
  # window size and then round down all batch sizes to divisors of that window
  # size, so that a window can always be divided evenly into batches.
  # TODO(noam): remove this when Dataset API improves.
  highly_composite_numbers = [
      1, 2, 4, 6, 12, 24, 36, 48, 60, 120, 180, 240, 360, 720, 840, 1260, 1680,
      2520, 5040, 7560, 10080, 15120, 20160, 25200, 27720, 45360, 50400, 55440,
      83160, 110880, 166320, 221760, 277200, 332640, 498960, 554400, 665280,
      720720, 1081080, 1441440, 2162160, 2882880, 3603600, 4324320, 6486480,
      7207200, 8648640, 10810800, 14414400, 17297280, 21621600, 32432400,
      36756720, 43243200, 61261200, 73513440, 110270160
  ]
  window_size = max(
      [i for i in highly_composite_numbers if i <= 3 * max_batch_size])
  divisors = [i for i in range(1, window_size + 1) if window_size % i == 0]
  batch_sizes = [max([d for d in divisors if d <= bs]) for bs in batch_sizes]
  window_size *= shard_multiplier
  batch_sizes = [bs * shard_multiplier for bs in batch_sizes]
  # The Datasets API splits one window into multiple batches, which
  # produces runs of many consecutive batches of the same size. This
  # is bad for training. To solve this, we will shuffle the batches
  # using a queue which must be several times as large as the maximum
  # number of batches per window.
  max_batches_per_window = window_size // min(batch_sizes)
  shuffle_queue_size = max_batches_per_window * 3

  ret = {
      "boundaries": boundaries,
      "batch_sizes": batch_sizes,
      "min_length": min_length,
      "max_length": (max_length if drop_long_sequences else 10**9),
      "shuffle_queue_size": shuffle_queue_size,
  }
  return ret
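
# Illustration (added comment, not in the original module): with
# batch_size=4096 tokens and no multipliers, a bucket whose boundary is 64
# gets a batch size of about 4096 // 64 = 64 examples, which is then rounded
# down to the nearest divisor of the chosen window size so that every window
# splits evenly into batches.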


def hparams_to_batching_scheme(hparams,
                               drop_long_sequences=False,
                               shard_multiplier=1,
                               length_multiplier=1):
  """Wrapper around _batching_scheme with hparams."""
  return _batching_scheme(
      batch_size=hparams.batch_size,
      min_length=hparams.min_length,
      max_length=hparams.max_length,
      min_length_bucket=hparams.min_length_bucket,
      length_bucket_step=hparams.length_bucket_step,
      drop_long_sequences=drop_long_sequences,
      shard_multiplier=shard_multiplier,
      length_multiplier=length_multiplier)


class DummyQueueRunner(object):
  """Can stand-in for a QueueRunner but does nothing."""

  def __init__(self):
    pass

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    del sess, coord, daemon, start
    return []


def _pad_for_tpu(shapes_dict, hparams, max_length):
  """Pads unknown features' dimensions for TPU."""
  padded_shapes = {}

  def get_filler(specified_max_length):
    if not specified_max_length:
      return max_length
    return min(specified_max_length, max_length)

  inputs_none_filler = get_filler(hparams.max_input_seq_length)
  targets_none_filler = get_filler(hparams.max_target_seq_length)

  def pad_one_shape(shape, none_filler):
    return [
        (dim if dim is not None else none_filler) for dim in shape.as_list()
    ]

  for key, shape in six.iteritems(shapes_dict):
    if key == "inputs":
      padded_shapes[key] = pad_one_shape(shape, inputs_none_filler)
    elif key == "targets":
      padded_shapes[key] = pad_one_shape(shape, targets_none_filler)
    else:
      padded_shapes[key] = pad_one_shape(shape, max_length)
  return padded_shapes
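
# Illustration (added comment, not in the original module): with
# max_length=256 and no per-feature limits set in hparams, an "inputs" shape
# of [None, None, 1] is padded to the static shape [256, 256, 1]; dimensions
# that are already known are left unchanged.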


def cpu_count():
  """Return the number of available cores."""
  num_available_cores = multiprocessing.cpu_count()
  return num_available_cores


def _summarize_features(features, num_shards=1):
  with tf.name_scope("input_stats"):
    for (k, v) in six.iteritems(features):
      if isinstance(v, tf.Tensor) and v.get_shape().ndims > 1:
        tf.summary.scalar("%s_batch" % k, tf.shape(v)[0] // num_shards)
        tf.summary.scalar("%s_length" % k, tf.shape(v)[1])
        nonpadding = tf.to_float(tf.not_equal(v, 0))
        nonpadding_tokens = tf.reduce_sum(nonpadding)
        tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
        tf.summary.scalar("%s_nonpadding_fraction" % k,
                          tf.reduce_mean(nonpadding))


def standardize_shapes(features, batch_size=None):
  """Set the right shapes for the features."""
  for fname in ["inputs", "targets"]:
    if fname not in features:
      continue
    f = features[fname]
    while len(f.get_shape()) < 4:
      f = tf.expand_dims(f, axis=-1)
    features[fname] = f

  if batch_size:
    # Ensure batch size is set on all features
    for _, t in six.iteritems(features):
      shape = t.get_shape().as_list()
      shape[0] = batch_size
      t.set_shape(t.get_shape().merge_with(shape))
      # Assert shapes are fully known
      t.get_shape().assert_is_fully_defined()

  return features
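
# Illustration (added comment, not in the original module): a batched 2-D
# "inputs" tensor of shape [batch, length] is expanded to rank 4, i.e.
# [batch, length, 1, 1], and when batch_size is given the leading dimension
# is fixed so that every shape becomes fully defined (as TPU execution
# requires).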


def _are_shapes_fully_defined(shapes_dict):
  for shape in shapes_dict.values():
    if not shape.is_fully_defined():
      return False
  return True


def _file_num_records_cached(filename):
  """Return the number of TFRecords in a file."""
  # Cache the result, as this is expensive to compute
  if filename in _file_num_records_cache:
    return _file_num_records_cache[filename]
  ret = 0
  for _ in tf.python_io.tf_record_iterator(filename):
    ret += 1
  _file_num_records_cache[filename] = ret
  return ret


_file_num_records_cache = {}


def skip_random_fraction(dataset, data_file):
  # Skip a random fraction at the beginning of the stream. The skip is
  # essential for synchronous highly-parallel training to avoid multiple
  # replicas reading the same data in lock-step.
  num_skip = random.randint(0, _file_num_records_cached(data_file))
  return dataset.skip(num_skip)


def _pad_batch(features, batch_multiple):
  """Pad batch dim of features to nearest multiple of batch_multiple."""
  feature = list(features.items())[0][1]
  batch_size = tf.shape(feature)[0]
  mod = batch_size % batch_multiple
  has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32)
  batch_padding = batch_multiple * has_mod - mod

  padded_features = {}
  for k, feature in features.items():
    rank = len(feature.shape)
    paddings = []
    for _ in range(rank):
      paddings.append([0, 0])
    paddings[0][1] = batch_padding
    padded_feature = tf.pad(feature, paddings)
    padded_features[k] = padded_feature
  return padded_features
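
# Illustration (added comment, not in the original module): a remainder batch
# of 5 examples with batch_multiple=4 gets 3 all-zero rows appended along the
# batch dimension, yielding a batch of 8; batches that are already a multiple
# of batch_multiple are left untouched (mod == 0, so no padding is added).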


# TODO(lukaszkaiser): refactor the API to not be just a list of self params
# but make sense for other uses too.
def input_fn(dataset,
             filepattern,
             skip_random_fraction_when_training,
             batch_size_means_tokens_param,
             batch_size_multiplier,
             max_length,
             mode,
             hparams,
             data_dir=None,
             params=None,
             config=None,
             force_repeat=False,
             prevent_repeat=False):
  """Builds input pipeline for problem.

  Args:
    dataset: the dataset to make input function from.
    filepattern: the pattern of files to read from.
    skip_random_fraction_when_training: whether to skip randomly when training.
    batch_size_means_tokens_param: whether batch size should mean tokens.
    batch_size_multiplier: how to multiply batch size when bucketing.
    max_length: maximum length.
    mode: tf.estimator.ModeKeys
    hparams: HParams, model hparams
    data_dir: str, data directory; if None, will use hparams.data_dir
    params: dict, may include "batch_size"
    config: RunConfig; should have the data_parallelism attribute if not using
      TPU
    force_repeat: bool, whether to repeat the data even if not training
    prevent_repeat: bool, whether to not repeat when in training mode.
      Overrides force_repeat.

  Returns:
    (features_dict<str name, Tensor feature>, Tensor targets)
  """
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  if config and config.use_tpu:
    num_threads = 64
  else:
    num_threads = cpu_count() if is_training else 1

  if config and hasattr(config,
                        "data_parallelism") and config.data_parallelism:
    num_shards = config.data_parallelism.n
  else:
    num_shards = 1

  mlperf_log.transformer_print(
      key=mlperf_log.INPUT_MAX_LENGTH, value=max_length)

  def tpu_valid_size(example):
    return example_valid_size(example, hparams.min_length, max_length)

  def gpu_valid_size(example):
    drop_long_sequences = is_training or hparams.eval_drop_long_sequences
    max_validate_length = max_length if drop_long_sequences else 10**9
    return example_valid_size(example, hparams.min_length, max_validate_length)

  def define_shapes(example):
    batch_size = config and config.use_tpu and params["batch_size"]
    return standardize_shapes(example, batch_size=batch_size)

  # Read and preprocess
  data_dir = data_dir or (hasattr(hparams, "data_dir") and hparams.data_dir)

  if (force_repeat or is_training) and not prevent_repeat:
    # Repeat and skip a random number of records
    dataset = dataset.repeat()

  if is_training and skip_random_fraction_when_training:
    data_files = tf.contrib.slim.parallel_reader.get_data_files(filepattern)
    # In continuous_train_and_eval when switching between train and
    # eval, this input_fn method gets called multiple times and it
    # would give you the exact same samples from the last call
    # (because the Graph seed is set). So this skip gives you some
    # shuffling.
    dataset = skip_random_fraction(dataset, data_files[0])

  dataset = dataset.map(cast_ints_to_int32, num_parallel_calls=num_threads)

  if batch_size_means_tokens_param:
    batch_size_means_tokens = True
  else:
    if _are_shapes_fully_defined(dataset.output_shapes):
      batch_size_means_tokens = False
    else:
      tf.logging.warning(
          "Shapes are not fully defined. Assuming batch_size means tokens.")
      batch_size_means_tokens = True

  # Batching
  if not batch_size_means_tokens:
    # Batch size means examples per datashard.
    if config and config.use_tpu:
      # on TPU, we use params["batch_size"], which specifies the number of
      # examples across all datashards
      batch_size = params["batch_size"]
      dataset = dataset.batch(batch_size, drop_remainder=True)
    else:
      batch_size = hparams.batch_size * num_shards
      dataset = dataset.batch(batch_size)
  else:
    # batch_size means tokens per datashard
    if config and config.use_tpu:
      dataset = dataset.filter(tpu_valid_size)
      padded_shapes = _pad_for_tpu(dataset.output_shapes, hparams, max_length)
      # on TPU, we use params["batch_size"], which specifies the number of
      # examples across all datashards
      batch_size = params["batch_size"]
      if hparams.pad_batch:
        tf.logging.warn(
            "Padding the batch to ensure that remainder eval batches are "
            "processed. This may lead to incorrect metrics for "
            "non-zero-padded features, e.g. images. Use a smaller batch "
            "size that has no remainder in that case.")
        dataset = dataset.padded_batch(
            batch_size, padded_shapes, drop_remainder=False)
        dataset = dataset.map(
            functools.partial(_pad_batch, batch_multiple=batch_size),
            num_parallel_calls=num_threads)
      else:
        dataset = dataset.padded_batch(
            batch_size, padded_shapes, drop_remainder=True)
    else:
      # On GPU, bucket by length
      dataset = dataset.filter(gpu_valid_size)
      batching_scheme = hparams_to_batching_scheme(
          hparams,
          shard_multiplier=num_shards,
          length_multiplier=batch_size_multiplier)
      if hparams.use_fixed_batch_size:
        # Here batch_size really means examples per datashard.
        batching_scheme["batch_sizes"] = [hparams.batch_size]
        batching_scheme["boundaries"] = []
      dataset = dataset.apply(
          tf.data.experimental.bucket_by_sequence_length(
              example_length, batching_scheme["boundaries"],
              batching_scheme["batch_sizes"]))

      if not is_training:
        batch_multiple = num_shards
        if hparams.use_fixed_batch_size:
          # Make sure the last batch has the same fixed size as the rest.
          batch_multiple *= hparams.batch_size
        if batch_multiple > 1:
          tf.logging.warn(
              "Padding the batch to ensure that remainder eval batches have "
              "a batch size divisible by the number of data shards. This may "
              "lead to incorrect metrics for non-zero-padded features, e.g. "
              "images. Use a single datashard (i.e. 1 GPU) in that case.")
          dataset = dataset.map(
              functools.partial(_pad_batch, batch_multiple=batch_multiple),
              num_parallel_calls=num_threads)

  dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)

  # Add shuffling for training batches. This is necessary along with record
  # level shuffling in the dataset generation. Record shuffling will shuffle
  # the examples. However, in some cases, it's possible that the shuffle
  # buffer size for record shuffling is smaller than the batch size. In such
  # cases, adding batch shuffling ensures that the data is in random order
  # during training.
  if (is_training and hasattr(hparams, "batch_shuffle_size") and
      hparams.batch_shuffle_size):
    dataset = dataset.shuffle(hparams.batch_shuffle_size)

  def prepare_for_output(example):
    if not config or not config.use_tpu:
      _summarize_features(example, num_shards)
    if mode == tf.estimator.ModeKeys.PREDICT:
      example["infer_targets"] = example.pop("targets")
      return example
    else:
      return example, example["targets"]

  dataset = dataset.map(prepare_for_output, num_parallel_calls=num_threads)
  dataset = dataset.prefetch(2)

  if mode == tf.estimator.ModeKeys.PREDICT:
    # This is because of a bug in the Estimator that short-circuits prediction
    # if it doesn't see a QueueRunner. DummyQueueRunner implements the
    # minimal expected interface but does nothing.
    tf.add_to_collection(tf.GraphKeys.QUEUE_RUNNERS, DummyQueueRunner())

  return dataset
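
# Example wiring (added sketch, not part of the original module; the problem
# object, hparams, and file pattern below are hypothetical placeholders):
#
#   train_dataset = my_problem.dataset(tf.estimator.ModeKeys.TRAIN, data_dir)
#   train_input = input_fn(
#       train_dataset,
#       filepattern=data_dir + "/*train*",
#       skip_random_fraction_when_training=True,
#       batch_size_means_tokens_param=True,
#       batch_size_multiplier=1,
#       max_length=my_hparams.max_length,
#       mode=tf.estimator.ModeKeys.TRAIN,
#       hparams=my_hparams,
#       config=my_run_config)
#
#   # train_input is then a tf.data.Dataset yielding (features, targets)
#   # pairs, batched and bucketed according to the hparams above.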