update how_tos/reading_data to use Dataset API #14751

Merged · 7 commits · Dec 11, 2017
27 changes: 19 additions & 8 deletions tensorflow/docs_src/api_guides/python/reading_data.md
@@ -173,14 +173,25 @@
For example,
[`tensorflow/examples/how_tos/reading_data/convert_to_records.py`](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/convert_to_records.py)
converts MNIST data to this format.
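
For readers who want the gist without opening that file, here is a minimal sketch of writing such records. The `images` and `labels` arrays are placeholders for data that `convert_to_records.py` actually reads from the MNIST dataset; the feature names match the ones parsed by the reader below:

``` python
import tensorflow as tf

def _int64_feature(value):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# `images` (flat uint8 arrays) and `labels` (ints) are assumed to exist;
# convert_to_records.py obtains them from the MNIST dataset.
with tf.python_io.TFRecordWriter('train.tfrecords') as writer:
  for image, label in zip(images, labels):
    example = tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(image.tostring()),
        'label': _int64_feature(int(label)),
    }))
    writer.write(example.SerializeToString())
```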

To read a file of TFRecords, use
@{tf.TFRecordReader} with
the @{tf.parse_single_example}
decoder. The `parse_single_example` op decodes the example protocol buffers into
tensors. An MNIST example using the data produced by `convert_to_records` can be
found in
[`tensorflow/examples/how_tos/reading_data/fully_connected_reader.py`](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py),
which you can compare with the `fully_connected_feed` version.
The recommended way to read a TFRecord file is with a @{tf.data.TFRecordDataset}, [as in this example](https://www.tensorflow.org/code/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py):

``` python
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.repeat(num_epochs)

# map applies a function to each record in the dataset
dataset = dataset.map(decode)
```
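
The `decode` function is defined in `fully_connected_reader.py` (changed later in this same PR); condensed, it parses one serialized `tf.train.Example` into an image/label pair. This sketch assumes the `mnist` module from that example is imported for `IMAGE_PIXELS`:

``` python
def decode(serialized_example):
  # Parse a single tf.train.Example protocol buffer into a dict of tensors.
  features = tf.parse_single_example(
      serialized_example,
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })
  # Recover the flattened uint8 image and cast the label to int32.
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape([mnist.IMAGE_PIXELS])
  label = tf.cast(features['label'], tf.int32)
  return image, label
```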

Accomplishing the same task with a queue-based input pipeline requires the following code
(using the same `decode` function from the above example):

``` python
filename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
image, label = decode(serialized_example)
```
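
Unlike the Dataset version, the queue-based pipeline also needs queue runners started inside a session before any examples can be pulled. A minimal sketch of that extra plumbing, using the same APIs the old example relied on:

``` python
with tf.Session() as sess:
  sess.run(tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer()))
  # Start the threads that fill the filename and example queues.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      img, lbl = sess.run([image, label])
  except tf.errors.OutOfRangeError:
    pass  # raised once num_epochs of data have been consumed
  finally:
    coord.request_stop()
    coord.join(threads)
```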

### Preprocessing

125 changes: 57 additions & 68 deletions tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -45,9 +45,7 @@
VALIDATION_FILE = 'validation.tfrecords'


def read_and_decode(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
def decode(serialized_example):
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
@@ -60,22 +58,26 @@ def read_and_decode(filename_queue):
# length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
# [mnist.IMAGE_PIXELS].
image = tf.decode_raw(features['image_raw'], tf.uint8)
image.set_shape([mnist.IMAGE_PIXELS])
image.set_shape((mnist.IMAGE_PIXELS))

# Convert label from a scalar uint8 tensor to an int32 scalar.
label = tf.cast(features['label'], tf.int32)

return image, label

def augment(image, label):
# OPTIONAL: Could reshape into a 28x28 image and apply distortions
# here. Since we are not applying any distortions in this
# example, and the next step expects the image to be flattened
# into a vector, we don't bother.
return image, label

def normalize(image, label):
# Convert from [0, 255] -> [-0.5, 0.5] floats.
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

# Convert label from a scalar uint8 tensor to an int32 scalar.
label = tf.cast(features['label'], tf.int32)

return image, label


def inputs(train, batch_size, num_epochs):
"""Reads input data num_epochs times.

@@ -91,31 +93,32 @@ def inputs(train, batch_size, num_epochs):
in the range [-0.5, 0.5].
* labels is an int32 tensor with shape [batch_size] with the true label,
a number in the range [0, mnist.NUM_CLASSES).
Note that a tf.train.QueueRunner is added to the graph, which
must be run using, e.g., tf.train.start_queue_runners().

This function creates a one_shot_iterator, meaning that it will only iterate
over the dataset once. On the other hand, it requires no special
initialization.
"""
if not num_epochs: num_epochs = None
filename = os.path.join(FLAGS.train_dir,
TRAIN_FILE if train else VALIDATION_FILE)

with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer(
[filename], num_epochs=num_epochs)
# TFRecordDataset opens a TFRecord file and reads one serialized
# example at a time; the argument could also be a [list, of, filenames]
dataset = tf.data.TFRecordDataset(filename)
dataset = dataset.repeat(num_epochs)

# Even when reading in multiple threads, share the filename
# queue.
image, label = read_and_decode(filename_queue)
# map applies a function to each record in the dataset
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)

# Shuffle the examples and collect them into batch_size batches.
# (Internally uses a RandomShuffleQueue.)
# We run this in two threads to avoid being a bottleneck.
images, sparse_labels = tf.train.shuffle_batch(
[image, label], batch_size=batch_size, num_threads=2,
capacity=1000 + 3 * batch_size,
# Ensures a minimum amount of shuffling of examples.
min_after_dequeue=1000)
# the parameter sets the size of the shuffle buffer
dataset = dataset.shuffle(1000 + 3 * batch_size)
dataset = dataset.batch(batch_size)

return images, sparse_labels
iterator = dataset.make_one_shot_iterator()
return iterator.get_next()


def run_training():
@@ -124,16 +127,16 @@ def run_training():
# Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default():
# Input images and labels.
images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
num_epochs=FLAGS.num_epochs)
image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size,
num_epochs=FLAGS.num_epochs)

# Build a Graph that computes predictions from the inference model.
logits = mnist.inference(images,
logits = mnist.inference(image_batch,
FLAGS.hidden1,
FLAGS.hidden2)

# Add to the Graph the loss calculation.
loss = mnist.loss(logits, labels)
loss = mnist.loss(logits, label_batch)

# Add to the Graph operations that train the model.
train_op = mnist.training(loss, FLAGS.learning_rate)
@@ -143,47 +146,33 @@ def run_training():
tf.local_variables_initializer())

# Create a session for running operations in the Graph.
sess = tf.Session()

# Initialize the variables (the trained variables and the
# epoch counter).
sess.run(init_op)

# Start input enqueue threads.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
step = 0
while not coord.should_stop():
start_time = time.time()

# Run one step of the model. The return values are
# the activations from the `train_op` (which is
# discarded) and the `loss` op. To inspect the values
# of your ops or variables, you may include them in
# the list passed to sess.run() and the value tensors
# will be returned in the tuple from the call.
_, loss_value = sess.run([train_op, loss])

duration = time.time() - start_time

# Print an overview fairly often.
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
with tf.Session() as sess:
# Initialize the variables (the trained variables and the
# epoch counter).
sess.run(init_op)
try:
step = 0
while True:  # train until OutOfRangeError
start_time = time.time()

# Run one step of the model. The return values are
# the activations from the `train_op` (which is
# discarded) and the `loss` op. To inspect the values
# of your ops or variables, you may include them in
# the list passed to sess.run() and the value tensors
# will be returned in the tuple from the call.
_, loss_value = sess.run([train_op, loss])

duration = time.time() - start_time

# Print an overview fairly often.
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
duration))
step += 1
except tf.errors.OutOfRangeError:
print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
finally:
# When done, ask the threads to stop.
coord.request_stop()

# Wait for threads to finish.
coord.join(threads)
sess.close()


step += 1
except tf.errors.OutOfRangeError:
print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))

def main(_):
run_training()
