# Snapshot of commit b994ea9 (May 11, 2016); 447 lines (365 sloc), 17.3 KB.
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Read and preprocess image data.
Image processing occurs on a single image at a time. Images are read and
preprocessed in parallel across multiple threads. The resulting images
are concatenated together to form a single batch for training or evaluation.
-- Provide processed image data for a network:
inputs: Construct batches of evaluation examples of images.
distorted_inputs: Construct batches of training examples of images.
batch_inputs: Construct batches of training or evaluation examples of images.
-- Data processing:
parse_example_proto: Parses an Example proto containing a training example
of an image.
-- Image decoding:
decode_jpeg: Decode a JPEG encoded string into a 3-D float32 Tensor.
-- Image preprocessing:
image_preprocessing: Decode and preprocess one image for evaluation or training
distort_image: Distort one image for training a network.
eval_image: Prepare one image for evaluation.
distort_color: Distort the color in one image for training.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
# Handle to the process-wide flag values; flags are read as attributes
# (e.g. FLAGS.num_preprocess_threads).
FLAGS = tf.app.flags.FLAGS

# batch_inputs() rejects thread counts that are not a multiple of 4.
tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
                            """Number of preprocessing threads per tower. """
                            """Please make this a multiple of 4.""")
tf.app.flags.DEFINE_integer('num_readers', 4,
                            """Number of parallel readers during train.""")

# Images are preprocessed asynchronously using multiple threads specified by
# --num_preprocess_threads and the resulting processed images are stored in a
# random shuffling queue. The shuffling queue dequeues --batch_size images
# for processing on a given Inception tower. A larger shuffling queue guarantees
# better mixing across examples within a batch and results in slightly higher
# predictive performance in a trained model. Empirically,
# --input_queue_memory_factor=16 works well. A value of 16 implies a queue size
# of 1024*16 images. Assuming RGB 299x299 images, this implies a queue size of
# 16GB. If the machine is memory limited, then decrease this factor to
# decrease the CPU memory footprint, accordingly.
# NOTE(review): in this file the queue sized by this flag (examples_queue in
# batch_inputs) is declared with dtypes=[tf.string] and is fed file paths, so
# the memory estimate above presumably overstates real usage -- confirm.
tf.app.flags.DEFINE_integer(
    'input_queue_memory_factor', 16,
    """Size of the queue of preprocessed images. """
    """Default is ideal but try smaller values, e.g. """
    """4, 2 or 1, if host memory is constrained. See """
    """comments in code for more details.""")
def inputs(dataset, batch_size=None, num_preprocess_threads=None):
  """Build evaluation batches of images and labels.

  Thin wrapper around batch_inputs() with train=False and a single reader.
  Evaluation preprocessing is light: batch_inputs() routes each image through
  the non-training branch of image_preprocessing().

  Args:
    dataset: instance of Dataset class specifying the dataset.
    batch_size: integer, number of examples per batch. Any falsy value
      (None or 0) falls back to FLAGS.batch_size.
    num_preprocess_threads: integer, total number of preprocessing threads.
      None is forwarded as-is, in which case batch_inputs() substitutes
      FLAGS.num_preprocess_threads.

  Returns:
    images: 4-D float Tensor of size
      [batch_size, FLAGS.image_size, FLAGS.image_size, 3].
    labels: 1-D integer Tensor of size [batch_size].
  """
  effective_batch_size = batch_size if batch_size else FLAGS.batch_size

  # Keep the whole input pipeline on the CPU so the GPU stays free for the
  # forward and backward passes.
  with tf.device('/cpu:0'):
    batch = batch_inputs(
        dataset,
        effective_batch_size,
        train=False,
        num_preprocess_threads=num_preprocess_threads,
        num_readers=1)

  eval_images, eval_labels = batch
  return eval_images, eval_labels
def decode_jpeg(image_buffer, scope=None):
  """Turn a JPEG-encoded string into a 3-D float32 image Tensor.

  Args:
    image_buffer: scalar string Tensor holding the encoded JPEG bytes.
    scope: Optional scope for op_scope.

  Returns:
    3-D float32 Tensor with three channels. Height and width are not known
    statically; they are determined when the decode op runs.
  """
  with tf.op_scope([image_buffer], scope, 'decode_jpeg'):
    # channels=3 forces RGB output regardless of how the file was stored.
    decoded = tf.image.decode_jpeg(image_buffer, channels=3)

    # Converting to float32 rescales the pixel values to the unit range,
    # which is what the tf.image adjust_*/random_* colour ops expect for
    # float inputs. The final rescale to [-1, 1] happens in
    # image_preprocessing().
    return tf.image.convert_image_dtype(decoded, dtype=tf.float32)
def distort_color(image, thread_id=0, scope=None):
  """Apply random brightness/saturation/hue/contrast jitter to an image.

  The colour ops do not commute, so their order matters. Instead of shuffling
  the order at random, one of two fixed orderings is picked from the parity
  of thread_id, so different preprocessing threads use different orderings.

  Args:
    image: Tensor containing a single image.
    thread_id: preprocessing thread ID; only `thread_id % 2` is used.
    scope: Optional scope for op_scope.

  Returns:
    The colour-distorted image, clipped to [0.0, 1.0].
  """
  with tf.op_scope([image], scope, 'distort_color'):

    def _brightness(img):
      return tf.image.random_brightness(img, max_delta=32. / 255.)

    def _saturation(img):
      return tf.image.random_saturation(img, lower=0.5, upper=1.5)

    def _hue(img):
      return tf.image.random_hue(img, max_delta=0.2)

    def _contrast(img):
      return tf.image.random_contrast(img, lower=0.5, upper=1.5)

    ordering = thread_id % 2
    if ordering == 0:
      steps = (_brightness, _saturation, _hue, _contrast)
    elif ordering == 1:
      steps = (_brightness, _contrast, _saturation, _hue)
    else:
      # Neither ordering matched: apply no colour ops, only the final clip.
      steps = ()

    for step in steps:
      image = step(image)

    # The random_* ops can push values outside the unit range, so clamp.
    return tf.clip_by_value(image, 0.0, 1.0)
def distort_image(image, height, width, bbox, thread_id=0, scope=None):
  """Produce a randomly augmented training image.

  The image is resized to (height, width), randomly mirrored left/right and
  colour-jittered. For a falsy thread_id (i.e. thread 0) two image summaries
  are also emitted so the result can be inspected.

  Args:
    image: 3-D float Tensor of image.
    height: integer, target height.
    width: integer, target width.
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords].
      It is only handed to op_scope here; no cropping is performed with it.
    thread_id: integer indicating the preprocessing thread. Selects the resize
      method (thread_id % 4) and the colour-op ordering.
    scope: Optional scope for op_scope.

  Returns:
    3-D float Tensor of the distorted image, static shape [height, width, 3].
  """
  with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
    # NOTE(ry) I unceremoniously removed all the bounding box code.
    # Original here: https://github.com/tensorflow/models/blob/148a15fb043dacdd1595eb4c5267705fbd362c6a/inception/inception/image_processing.py
    emit_summaries = not thread_id

    # Resizing ignores the aspect ratio, so the image may be stretched. The
    # resize method is chosen round-robin from the thread number.
    resized = tf.image.resize_images(image, height, width, thread_id % 4)

    # Pin the static shape so downstream batching sees fixed dimensions.
    resized.set_shape([height, width, 3])
    if emit_summaries:
      tf.image_summary('cropped_resized_image', tf.expand_dims(resized, 0))

    # Random horizontal mirror followed by random colour jitter.
    flipped = tf.image.random_flip_left_right(resized)
    result = distort_color(flipped, thread_id)

    if emit_summaries:
      tf.image_summary('final_distorted_image', tf.expand_dims(result, 0))
    return result
def eval_image(image, height, width, scope=None):
  """Centre-crop and resize one image for evaluation.

  Args:
    image: 3-D float Tensor.
    height: integer, output height.
    width: integer, output width.
    scope: Optional scope for op_scope.

  Returns:
    3-D float Tensor of the prepared image.
  """
  with tf.op_scope([image, height, width], scope, 'eval_image'):
    # Keep only the central region (central_fraction=0.875).
    cropped = tf.image.central_crop(image, central_fraction=0.875)

    # resize_bilinear is called on a batch, so add a leading batch dimension
    # of size 1 for the resize and strip it again afterwards.
    batched = tf.expand_dims(cropped, 0)
    resized = tf.image.resize_bilinear(batched, [height, width],
                                       align_corners=False)
    return tf.squeeze(resized, [0])
def image_preprocessing(image_buffer, bbox, train, thread_id=0):
  """Decode one JPEG and run the training or evaluation preprocessing on it.

  Args:
    image_buffer: JPEG encoded string Tensor.
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      where each coordinate is [0, 1) and the coordinates are arranged as
      [ymin, xmin, ymax, xmax]. Must not be None.
    train: boolean; True selects distort_image(), False selects eval_image().
    thread_id: integer indicating preprocessing thread.

  Returns:
    3-D float Tensor of the processed image, rescaled via (x - 0.5) * 2.

  Raises:
    ValueError: if bbox is None.
  """
  if bbox is None:
    raise ValueError('Please supply a bounding box.')

  decoded = decode_jpeg(image_buffer)

  # Output is square: the same flag provides both height and width.
  # NOTE(review): this reads FLAGS.input_size while batch_inputs() reshapes
  # with FLAGS.image_size -- confirm both flags exist and agree.
  side = FLAGS.input_size

  if train:
    processed = distort_image(decoded, side, side, bbox, thread_id)
  else:
    processed = eval_image(decoded, side, side)

  # Shift and scale from the unit range to [-1, 1].
  return tf.mul(tf.sub(processed, 0.5), 2.0)
def parse_example_proto(example_serialized):
  """Extract filename, label, bounding boxes and label text from an Example.

  The dataset produced by the build_image_data.py preprocessing script is
  made of serialized Example protocol buffers with fields such as:

    image/height: 462
    image/width: 581
    image/colorspace: 'RGB'
    image/channels: 3
    image/class/label: 615
    image/class/synset: 'n03623198'
    image/class/text: 'knee pad'
    image/object/bbox/xmin: 0.1
    image/object/bbox/xmax: 0.9
    image/object/bbox/ymin: 0.2
    image/object/bbox/ymax: 0.6
    image/object/bbox/label: 615
    image/format: 'JPEG'
    image/filename: 'ILSVRC2012_val_00041207.JPEG'
    image/encoded: <JPEG encoded string>

  Only image/filename, image/class/label, image/class/text and the four
  image/object/bbox/* coordinate fields are parsed here.

  Args:
    example_serialized: scalar Tensor tf.string containing a serialized
      Example protocol buffer.

  Returns:
    filename: Tensor tf.string containing the filename ('' if absent).
    label: Tensor tf.int32 containing the label, parsed with shape [1]
      (-1 if absent).
    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
      with the coordinates ordered [ymin, xmin, ymax, xmax].
    text: Tensor tf.string containing the human-readable label ('' if absent).
  """
  bbox_keys = ['image/object/bbox/xmin', 'image/object/bbox/ymin',
               'image/object/bbox/xmax', 'image/object/bbox/ymax']

  # Dense (fixed-length) features, each with a default for missing fields.
  feature_map = {
      'image/filename': tf.FixedLenFeature(
          [], dtype=tf.string, default_value=''),
      'image/class/label': tf.FixedLenFeature(
          [1], dtype=tf.int64, default_value=-1),
      'image/class/text': tf.FixedLenFeature(
          [], dtype=tf.string, default_value=''),
  }

  # Sparse (variable-length) features: an image can carry any number of
  # boxes. All four coordinate keys share one feature spec.
  sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
  for key in bbox_keys:
    feature_map[key] = sparse_float32

  parsed = tf.parse_single_example(example_serialized, feature_map)
  label = tf.cast(parsed['image/class/label'], dtype=tf.int32)

  def _coordinate_row(key):
    # Values of one coordinate across all boxes, as a [1, num_boxes] row.
    return tf.expand_dims(parsed[key].values, 0)

  xmin, ymin, xmax, ymax = [_coordinate_row(key) for key in bbox_keys]

  # Stack rows in (y, x) order -> [4, num_boxes], then add a leading axis and
  # swap the last two to obtain [1, num_boxes, 4].
  stacked = tf.concat(0, [ymin, xmin, ymax, xmax])
  bbox = tf.transpose(tf.expand_dims(stacked, 0), [0, 2, 1])

  return parsed['image/filename'], label, bbox, parsed['image/class/text']
def batch_inputs(dataset,
                 batch_size,
                 train,
                 num_preprocess_threads=None,
                 num_readers=1):
  """Construct batches of training or evaluation examples from the dataset.

  Pipeline: TFRecord files -> parsed Example -> queue of (image path, label)
  pairs -> per-thread file read and preprocessing -> batch_join.

  The label travels through the examples queue together with the image path.
  Fetching the label directly from the parse op instead would trigger a fresh
  TFRecord read for every batch element, pairing each image with the label of
  an unrelated record.

  Args:
    dataset: instance of Dataset class specifying the dataset.
      See dataset.py for details.
    batch_size: integer
    train: boolean; selects a shuffled filename producer and a
      RandomShuffleQueue (True) or ordered producer and a FIFOQueue (False).
    num_preprocess_threads: integer, total number of preprocessing threads.
      None falls back to FLAGS.num_preprocess_threads. Must be a multiple
      of 4.
    num_readers: integer, number of parallel readers. None falls back to
      FLAGS.num_readers. The value is validated, but a single TFRecordReader
      is created regardless.

  Returns:
    images: 4-D float Tensor of a batch of images
    labels: 1-D integer Tensor of [batch_size].

  Raises:
    ValueError: if data is not found, if num_preprocess_threads is not a
      multiple of 4, or if num_readers is less than 1.
  """
  with tf.name_scope('batch_processing'):
    data_files = dataset.data_files()
    if data_files is None:
      raise ValueError('No data files found for this dataset')

    # Create filename_queue
    if train:
      filename_queue = tf.train.string_input_producer(
          data_files, shuffle=True, capacity=16)
    else:
      filename_queue = tf.train.string_input_producer(
          data_files, shuffle=False, capacity=1)

    if num_preprocess_threads is None:
      num_preprocess_threads = FLAGS.num_preprocess_threads
    if num_preprocess_threads % 4:
      # '%%' yields a literal percent sign; the message is formatted here
      # rather than handed to ValueError as two separate arguments.
      raise ValueError('Please make num_preprocess_threads a multiple '
                       'of 4 (%d %% 4 != 0).' % num_preprocess_threads)

    if num_readers is None:
      num_readers = FLAGS.num_readers
    if num_readers < 1:
      raise ValueError('Please make num_readers at least 1')

    # Approximate number of examples per shard.
    examples_per_shard = 1024
    # Size the random shuffle queue to balance between good global
    # mixing (more examples) and memory use (fewer examples). Each queue
    # element is an (image path, label) pair.
    min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor

    # Scalar path string plus the shape-[1] int32 label from
    # parse_example_proto(). Declaring the shapes gives dequeued tensors the
    # static shapes that batch_join needs.
    queue_dtypes = [tf.string, tf.int32]
    queue_shapes = [[], [1]]
    if train:
      examples_queue = tf.RandomShuffleQueue(
          capacity=min_queue_examples + 3 * batch_size,
          min_after_dequeue=min_queue_examples,
          dtypes=queue_dtypes,
          shapes=queue_shapes)
    else:
      examples_queue = tf.FIFOQueue(
          capacity=examples_per_shard + 3 * batch_size,
          dtypes=queue_dtypes,
          shapes=queue_shapes)

    reader = tf.TFRecordReader()
    _, example_serialized = reader.read(filename_queue)
    filename, label_index, bbox, label_text = parse_example_proto(
        example_serialized)

    # Images live on disk at <data_dir>/<label text>/<filename>.
    fn = FLAGS.data_dir + '/' + label_text + '/' + filename

    # Path and label are enqueued by the same op, so both come from the same
    # record and stay paired through shuffling.
    enqueue_op = examples_queue.enqueue([fn, label_index])
    examples_qr = tf.train.queue_runner.QueueRunner(examples_queue,
                                                    [enqueue_op])
    tf.train.queue_runner.add_queue_runner(examples_qr)

    images_and_labels = []
    for thread_id in range(num_preprocess_threads):
      # Pull one (path, label) pair and load the raw JPEG bytes for it.
      image_path, image_label = examples_queue.dequeue()
      image_buffer = tf.read_file(image_path)
      # bbox only has to be non-None here; the preprocessing functions in
      # this file pass it to op_scope and never read its values.
      image = image_preprocessing(image_buffer, bbox, train, thread_id)
      images_and_labels.append([image, image_label])

    images, label_index_batch = tf.train.batch_join(
        images_and_labels,
        batch_size=batch_size,
        capacity=2 * num_preprocess_threads * batch_size)

    # Reshape images into these desired dimensions.
    # NOTE(review): image_preprocessing() sizes images with FLAGS.input_size
    # while this reshape uses FLAGS.image_size -- confirm both flags exist
    # and agree.
    height = FLAGS.image_size
    width = FLAGS.image_size
    depth = 3

    images = tf.cast(images, tf.float32)
    images = tf.reshape(images, shape=[batch_size, height, width, depth])

    # Display the training images in the visualizer.
    tf.image_summary('images', images)

    return images, tf.reshape(label_index_batch, [batch_size])