In [None]:
import tensorflow as tf
import os
import numpy as np
import sys

In [None]:
def _img_feature(value):
    """Wrap an image array as a ``tf.train.Feature`` holding one flat float list.

    Args:
        value: array-like exposing ``.flatten()`` (e.g. a NumPy array of pixels);
            all of its elements are stored, in flattened order.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` contains the flattened values.
    """
    flat_pixels = value.flatten()
    float_list = tf.train.FloatList(value=flat_pixels)
    return tf.train.Feature(float_list=float_list)

def process_img(filename, size=None, crop=True):
    """Load a JPEG file into a float image tensor, optionally cropped square and resized.

    Args:
        filename: path of the JPEG to read (a string tensor when called from a
            ``tf.data`` map).
        size: optional target size forwarded to ``tf.image.resize_images``
            (e.g. ``[height, width]``); when None no resize is done.
        crop: when True, crop/pad the decoded image to an s x s square, where
            s is the smaller of the JPEG's height and width.

    Returns:
        A 3-channel image tensor cast to float.
    """
    raw_bytes = tf.read_file(filename)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)

    if crop:
        # Dimensions come from the JPEG header rather than the decoded tensor.
        jpeg_shape = tf.image.extract_jpeg_shape(raw_bytes)
        side = tf.minimum(jpeg_shape[0], jpeg_shape[1])
        pixels = tf.image.resize_image_with_crop_or_pad(pixels, side, side)

    if size is not None:
        pixels = tf.image.resize_images(pixels, size)
    return tf.to_float(pixels)

def make_example(imgs):
    """Pack a sequence of images into one ``tf.train.SequenceExample``.

    Every image becomes one float-list feature (via ``_img_feature``) inside the
    feature list named ``'imgs'``, preserving the order of ``imgs``.

    Args:
        imgs: iterable of image arrays (e.g. one batch from the input pipeline).

    Returns:
        The populated ``tf.train.SequenceExample``.
    """
    per_image = [_img_feature(img) for img in imgs]
    return tf.train.SequenceExample(
        feature_lists=tf.train.FeatureLists(
            feature_list={'imgs': tf.train.FeatureList(feature=per_image)}))

def create_pipeline(sess, params):
    """Build an initializable iterator over batches of preprocessed JPEG images.

    Files matching ``params.train_path`` are shuffled, capped at
    ``params.total_train_sample``, loaded with ``process_img`` (square crop, then
    resize to ``params.input_shape[0:2]``), and grouped into batches of
    ``params.batch_size``. The last batch may be smaller.

    Args:
        sess: unused; kept so existing callers keep working.
        params: object exposing ``train_path``, ``total_train_sample``,
            ``input_shape`` and ``batch_size``.

    Returns:
        An initializable iterator; run its ``initializer`` before ``get_next()``.
    """
    target_hw = params.input_shape[0:2]
    sample_cap = params.total_train_sample

    def _load(path):
        return process_img(path, target_hw, crop=True)

    dataset = (tf.data.Dataset.list_files(params.train_path)
               .shuffle(sample_cap)
               .take(sample_cap)
               .map(_load)
               .batch(params.batch_size))
    return dataset.make_initializable_iterator()

def write():
    """Convert a folder of JPEGs into TFRecord files holding one batch each.

    Images are read through ``create_pipeline`` (square-cropped, resized,
    batched). Each batch is serialized as a single ``SequenceExample`` (see
    ``make_example``) and written to ``<out_path>/<batch_index>.tfr``.
    Progress is printed on a single, overwritten line.
    """
    class Params:
        def __init__(self):
            self.train_path = 'data/train/val2017/*.jpg'  # glob of source JPEGs
            self.out_path = 'data/train/val'              # directory receiving the .tfr files
            self.total_train_sample = 5000                # max number of images converted
            self.batch_size = 4                           # images per .tfr file
            self.input_shape = [256,256]                  # (height, width) after resize

    params = Params()

    # TFRecordWriter does not create missing parent directories; MakeDirs is a
    # no-op when the directory already exists.
    tf.gfile.MakeDirs(params.out_path)

    # Ceiling division: the pipeline's final batch may be partial but is still written.
    expected_batches = -(-params.total_train_sample // params.batch_size)

    tf.reset_default_graph()
    # A context-managed Session is closed on exit. The previous InteractiveSession
    # was never closed, so re-running this cell accumulated open sessions.
    with tf.Session() as sess:
        iterator = create_pipeline(sess, params)
        next_batch = iterator.get_next()
        sess.run(iterator.initializer)

        batch_index = 0
        print('Starting\n')
        while True:
            try:
                batch = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                break  # input pipeline exhausted

            output_file = os.path.join(params.out_path, str(batch_index) + ".tfr")
            with tf.python_io.TFRecordWriter(output_file) as writer:
                writer.write(make_example(batch).SerializeToString())

            batch_index += 1
            sys.stdout.write("\r")
            sys.stdout.write("Done %i/%i" % (batch_index, expected_batches))
            sys.stdout.flush()  # make the carriage-return progress line visible immediately

    print("\nDone")

# Run the JPEG -> TFRecord conversion: reads up to total_train_sample images
# and writes one .tfr file per batch, so this cell can be slow.
write()

In [None]:
class TrainParams:
    """Settings for reading back the TFRecord files written by the conversion step."""

    def __init__(self):
        self.train_path = 'data/train/val/*.tfr'  # glob of .tfr files to read
        self.total_train_sample = 8               # cap on how many matching files are listed
        self.batch_size = 4                       # images stored in each record
        self.input_shape = [256,256]              # (height, width) of each stored image

def process_tf(x, batch_size, shape=None):
    """Parse one serialized ``SequenceExample`` back into a batch of images.

    Args:
        x: scalar string tensor holding a serialized ``SequenceExample``. Its
            ``'imgs'`` feature list is expected to hold one float list of
            ``height * width * 3`` values per image.
        batch_size: nominal number of images per record. It is kept for
            backward compatibility but no longer hard-coded into the reshape,
            so a shorter final record still parses.
        shape: ``(height, width)`` of each image, as a list or tuple. Required.

    Returns:
        A float32 tensor of shape ``(num_images, height, width, 3)``.

    Raises:
        TypeError: if ``shape`` is None.
    """
    if shape is None:
        raise TypeError("process_tf requires shape=(height, width)")
    dims = list(shape)  # accept tuples as well as lists
    pixels_per_image = dims[0] * dims[1] * 3

    _, sequence = tf.parse_single_sequence_example(x, sequence_features={
        'imgs': tf.FixedLenSequenceFeature([pixels_per_image], dtype=tf.float32)
    })
    # -1 infers the image count: the writer's final batch can hold fewer than
    # batch_size images, which a fixed leading dimension would reject.
    return tf.reshape(sequence['imgs'], [-1] + dims + [3])
        
def create_tf_pipeline(sess, params):
    """Return a one-shot iterator over TFRecord filenames.

    Lists files matching ``params.train_path`` and yields at most
    ``params.total_train_sample`` of them.

    Args:
        sess: unused; kept so existing callers keep working.
        params: object exposing ``train_path`` and ``total_train_sample``.

    Returns:
        A one-shot iterator whose ``get_next()`` yields filename strings.
    """
    matched = tf.data.Dataset.list_files(params.train_path)
    limited = matched.take(params.total_train_sample)
    return limited.make_one_shot_iterator()

def read():
    """Read the TFRecord files back and print the shape of every image batch.

    Builds one ``tf.data`` pipeline:
    1. List up to ``total_train_sample`` ``.tfr`` files.
    2. Stream every record in them.
    3. Parse each record into an image batch with ``process_tf``.

    The pipeline is built once, so no new graph ops are added per file.
    """
    params = TrainParams()

    tf.reset_default_graph()
    # A context-managed Session is closed on exit (the previous InteractiveSession
    # leaked) and is the default session inside the block.
    with tf.Session() as sess:
        filenames = tf.data.Dataset.list_files(params.train_path).take(params.total_train_sample)
        # flat_map streams all records of all files through a single graph. Previously
        # the code built a new dataset and iterator per file (unbounded graph growth),
        # read only each file's first record, and stopped early at any empty file.
        records = filenames.flat_map(lambda f: tf.data.TFRecordDataset(f))
        batches = records.map(lambda x: process_tf(x, params.batch_size, params.input_shape[0:2]))
        next_batch = batches.make_one_shot_iterator().get_next()

        while True:
            try:
                batch = sess.run(next_batch)
            except tf.errors.OutOfRangeError:
                break  # every record of every listed file has been consumed
            print(batch.shape)

    print("Done")