Skip to content

Commit

Permalink
Fix create_input_pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
ppwwyyxx authored Dec 1, 2018
1 parent 096ed77 commit c57e822
Showing 1 changed file with 32 additions and 32 deletions.
64 changes: 32 additions & 32 deletions src/facenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,39 +101,39 @@ def random_rotate_image(image):
FIXED_STANDARDIZATION = 8
FLIP = 16
def create_input_pipeline(input_queue, image_size, nrof_preprocess_threads, batch_size_placeholder):
images_and_labels_list = []
for _ in range(nrof_preprocess_threads):
filenames, label, control = input_queue.dequeue()
images = []
for filename in tf.unstack(filenames):
file_contents = tf.read_file(filename)
image = tf.image.decode_image(file_contents, 3)
image = tf.cond(get_control_flag(control[0], RANDOM_ROTATE),
lambda:tf.py_func(random_rotate_image, [image], tf.uint8),
lambda:tf.identity(image))
image = tf.cond(get_control_flag(control[0], RANDOM_CROP),
lambda:tf.random_crop(image, image_size + (3,)),
lambda:tf.image.resize_image_with_crop_or_pad(image, image_size[0], image_size[1]))
image = tf.cond(get_control_flag(control[0], RANDOM_FLIP),
lambda:tf.image.random_flip_left_right(image),
lambda:tf.identity(image))
image = tf.cond(get_control_flag(control[0], FIXED_STANDARDIZATION),
lambda:(tf.cast(image, tf.float32) - 127.5)/128.0,
lambda:tf.image.per_image_standardization(image))
image = tf.cond(get_control_flag(control[0], FLIP),
lambda:tf.image.flip_left_right(image),
lambda:tf.identity(image))
#pylint: disable=no-member
image.set_shape(image_size + (3,))
images.append(image)
images_and_labels_list.append([images, label])
with tf.name_scope("scope_to_avoid_conflicts"):
images_and_labels_list = []
for _ in range(nrof_preprocess_threads):
filenames, label, control = input_queue.dequeue()
images = []
for filename in tf.unstack(filenames):
file_contents = tf.read_file(filename)
image = tf.image.decode_image(file_contents, 3)
image = tf.cond(get_control_flag(control[0], RANDOM_ROTATE),
lambda:tf.py_func(random_rotate_image, [image], tf.uint8),
lambda:tf.identity(image))
image = tf.cond(get_control_flag(control[0], RANDOM_CROP),
lambda:tf.random_crop(image, image_size + (3,)),
lambda:tf.image.resize_image_with_crop_or_pad(image, image_size[0], image_size[1]))
image = tf.cond(get_control_flag(control[0], RANDOM_FLIP),
lambda:tf.image.random_flip_left_right(image),
lambda:tf.identity(image))
image = tf.cond(get_control_flag(control[0], FIXED_STANDARDIZATION),
lambda:(tf.cast(image, tf.float32) - 127.5)/128.0,
lambda:tf.image.per_image_standardization(image))
image = tf.cond(get_control_flag(control[0], FLIP),
lambda:tf.image.flip_left_right(image),
lambda:tf.identity(image))
#pylint: disable=no-member
image.set_shape(image_size + (3,))
images.append(image)
images_and_labels_list.append([images, label])

image_batch, label_batch = tf.train.batch_join(
images_and_labels_list, batch_size=batch_size_placeholder,
shapes=[image_size + (3,), ()], enqueue_many=True,
capacity=4 * nrof_preprocess_threads * 100,
allow_smaller_final_batch=True)

image_batch, label_batch = tf.train.batch_join(
images_and_labels_list, batch_size=batch_size_placeholder,
shapes=[image_size + (3,), ()], enqueue_many=True,
capacity=4 * nrof_preprocess_threads * 100,
allow_smaller_final_batch=True)
return image_batch, label_batch

def get_control_flag(control, field):
Expand Down

0 comments on commit c57e822

Please sign in to comment.