-
Notifications
You must be signed in to change notification settings - Fork 0
/
day087.py
20 lines (19 loc) · 895 Bytes
/
day087.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# When using the tf.estimator.Estimator API,
# the first two phases (Extract and Transform) are captured in the input_fn
def parse_fn(example):
"Parse TFExample records and perform simple data augmentation."
example_fmt = {
"image": tf.FixedLengthFeature((), tf.string, ""),
"label": tf.FixedLengthFeature((), tf.int64, -1)
}
parsed = tf.parse_single_example(example, example_fmt)
image = tf.image.decode_image(parsed["image"])
image = _augment_helper(image) # augments image using slice, reshape, resize_bilinear
return image, parsed["label"]
def input_fn():
files = tf.data.Dataset.list_files("/path/to/dataset/train-*.tfrecord")
dataset = files.interleave(tf.data.TFRecordDataset)
dataset = dataset.shuffle(buffer_size=FLAGS.shuffle_buffer_size)
dataset = dataset.map(map_func=parse_fn)
dataset = dataset.batch(batch_size=FLAGS.batch_size)
return dataset