Tensorflow-DatasetAPI

A simple TensorFlow Dataset API tutorial for reading images.

Usage

1. Glob the images

ex)

from glob import glob

trainA_dataset = glob('./dataset/{}/trainA/*.*'.format(dataset_name))

trainA_dataset = ['./dataset/cat/trainA/a.jpg', 
                  './dataset/cat/trainA/b.png', 
                  './dataset/cat/trainA/c.jpeg', 
                  ...]
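
  • Note that glob returns paths in an arbitrary, filesystem-dependent order. If you need a deterministic order (e.g. for reproducible pairing of files), a small tweak of the call above is to wrap it in sorted(); this is my addition, not part of the original tutorial:

trainA_dataset = sorted(glob('./dataset/{}/trainA/*.*'.format(dataset_name)))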

2. Use from_tensor_slices

import tensorflow as tf

trainA = tf.data.Dataset.from_tensor_slices(trainA_dataset)

3. Use map for preprocessing

    def image_processing(filename):
        x = tf.read_file(filename)  # read the raw bytes
        x_decode = tf.image.decode_jpeg(x, channels=3)  # decode as RGB

        # DO NOT USE tf.image.decode_image here:
        # it returns a tensor without a static shape,
        # so tf.image.resize_images below will raise an error

        img = tf.image.resize_images(x_decode, [256, 256])
        img = tf.cast(img, tf.float32) / 127.5 - 1  # scale to [-1, 1]

        return img
        

trainA = trainA.map(image_processing, num_parallel_calls=8)
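
  • Why the decode_image warning? In TF 1.x, tf.image.decode_image returns a tensor with no static shape (a GIF decodes to 4-D, so even the rank is unknown), and tf.image.resize_images raises an error on inputs without shape information. If you must handle mixed formats, one common workaround, sketched here under the assumption that no GIFs are present, is to pin the shape yourself:

def image_processing_mixed(filename):
    x = tf.read_file(filename)
    x_decode = tf.image.decode_image(x, channels=3)  # handles jpeg/png/bmp/gif
    x_decode.set_shape([None, None, 3])  # pin the rank; fails at runtime on GIFs

    img = tf.image.resize_images(x_decode, [256, 256])
    img = tf.cast(img, tf.float32) / 127.5 - 1
    return img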
  • If you also want data augmentation, wrap the preprocessing in a class:
import random

class ImageData:

    def __init__(self, batch_size, load_size, channels, augment_flag):
        self.batch_size = batch_size
        self.load_size = load_size
        self.channels = channels
        self.augment_flag = augment_flag
        self.augment_size = load_size + (30 if load_size == 256 else 15)

    def image_processing(self, filename):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        
        # DO NOT USE tf.image.decode_image here:
        # it has no static shape, so resize_images will fail (see step 3)
        
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1

        if self.augment_flag:
            # decide per image inside the graph: a Python-level
            # random.random() here would be evaluated only once, when the
            # map function is traced, giving every image the same decision
            p = tf.random_uniform(shape=[], minval=0.0, maxval=1.0)
            img = tf.cond(p > 0.5,
                          lambda: self.augmentation(img),
                          lambda: img)

        return img
        
    def augmentation(self, image):
        seed = random.randint(0, 2 ** 31 - 1)

        # random horizontal flip, then enlarge and randomly crop back to the
        # original size (a small random zoom/translation)
        ori_image_shape = tf.shape(image)
        image = tf.image.random_flip_left_right(image, seed=seed)
        image = tf.image.resize_images(image, [self.augment_size, self.augment_size])
        image = tf.random_crop(image, ori_image_shape, seed=seed)

        return image
    
    
Image_Data_Class = ImageData(batch_size, img_size, img_ch, augment_flag)
trainA = trainA.map(Image_Data_Class.image_processing, num_parallel_calls=8)
  • Personally, I recommend num_parallel_calls = 8 or 16
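  • On newer 1.x releases (TF ≥ 1.13, where tf.data.experimental.AUTOTUNE exists), you can also let TensorFlow tune the parallelism for you; a sketch assuming such a version:

trainA = trainA.map(Image_Data_Class.image_processing,
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)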

4. Set prefetch & batch_size

trainA = trainA.shuffle(buffer_size=10000).prefetch(buffer_size=batch_size).batch(batch_size).repeat()
  • Personally, I recommend prefetch_size = batch_size or another small value
  • If shuffle's buffer_size is greater than the number of elements in the dataset, you get a uniform shuffle
  • If buffer_size is 1, you get no shuffling at all
  • If the number of elements N in the dataset is not an exact multiple of batch_size, the final batch contains smaller tensors with shape N % batch_size in the batch dimension
  • If your program depends on all batches having the same shape, consider using the tf.contrib.data.batch_and_drop_remainder transformation instead:
from tensorflow.contrib.data import batch_and_drop_remainder

trainA = trainA.shuffle(10000).prefetch(batch_size).apply(batch_and_drop_remainder(batch_size)).repeat()
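  • To see the remainder behavior concretely, here is a minimal self-contained sketch: 10 elements batched by 4 yield batches of 4, 4, and 2.

import tensorflow as tf

ds = tf.data.Dataset.range(10).batch(4)
nxt = ds.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(nxt))  # [0 1 2 3]
    print(sess.run(nxt))  # [4 5 6 7]
    print(sess.run(nxt))  # [8 9]  <- final batch has 10 % 4 = 2 elements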
  • If you use TensorFlow 1.8 or later, this is faster:
from tensorflow.contrib.data import shuffle_and_repeat, map_and_batch, prefetch_to_device

# hyper-parameter examples
gpu_device = '/gpu:0'
dataset_num = 10000
batch_size = 8

trainA = trainA.apply(shuffle_and_repeat(dataset_num)) \
               .apply(map_and_batch(Image_Data_Class.image_processing,
                                    batch_size,
                                    num_parallel_batches=16,
                                    drop_remainder=True)) \
               .apply(prefetch_to_device(gpu_device, batch_size))

5. Set Iterator

trainA_iterator = trainA.make_one_shot_iterator()

data_A = trainA_iterator.get_next()
logit = network(data_A)
...
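
  • Before feeding data_A into the network, it can help to pull a single batch and sanity-check it; a minimal sketch (assuming the step-4 pipeline above):

with tf.Session() as sess:
    batch = sess.run(data_A)
    print(batch.shape)               # (batch_size, 256, 256, 3)
    print(batch.min(), batch.max())  # roughly within [-1, 1] after scaling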

6. Run Model

def train():
    for epoch in range(epochs):
        for iteration in range(iterations_per_epoch):
            ...
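
  • A fuller sketch of the loop. The loss and optimizer here are hypothetical placeholders (this tutorial only defines the input pipeline and the network call); because the iterator is one-shot and the dataset repeats, each sess.run pulls the next batch automatically, with no feed_dict:

# hypothetical objective and optimizer, for illustration only
loss = tf.reduce_mean(tf.square(logit))
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

def train(epochs, iterations_per_epoch):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            for iteration in range(iterations_per_epoch):
                _, loss_val = sess.run([train_op, loss])
            print('epoch %d, loss %.4f' % (epoch, loss_val))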

7. See the full code: Templete_code.py

Author

Junho Kim
