In [12]:
import tensorflow as tf
import os
import sys
import glob
import random
import shutil
import numpy as np
import matplotlib.image as mpimg
from matplotlib import pyplot as plt
from tqdm import trange
from math import ceil

%matplotlib inline

In [27]:
class TFRecordGenerator:
    
    def __init__(self):
        pass
    
    def convert(self,
                directory,
                tfrecord_file_name,
                suffix='tfrecord',
                num_files_per_record=1000,
                shuffle=True,
                encode_image_shape=False):
        
        image_paths = get_image_paths(get_subdirectories(directory))
        
        if shuffle:
            random.shuffle(image_paths)
            
        num_files_total = len(image_paths)
        num_tfrecords = ceil(num_files_total / num_files_per_record)
        start = 0
        tfrecord_file_number = 0
        
        while start < num_files_total:
            num_files_remaining = num_files_total - start
            batch_size = num_files_per_record if (num_files_remaining >= num_files_per_record) else num_files_remaining
            with tf.python_io.TFRecordWriter(tfrecord_file_name + '_{:04d}.{}'.format(tfrecord_file_number, suffix)) as writer:
                for i in trange(batch_size, desc="Creating TFRecord file {}/{}".format(tfrecord_file_number+1, num_tfrecords), file=sys.stdout):
                    example = self._convert_sample(image_paths[start+i], encode_image_shape) # This is an instance of tf.Example
                    writer.write(example.SerializeToString())
            start += num_files_per_record
            tfrecord_file_number += 1
                
    def _convert_sample(self, image_path, encode_image_shape=False):
        
        # Convert the image.
        image_feature = self._convert_image(image_path, encode_image_shape)
        
        # Convert the label.
        label_feature = self._convert_image_class_from_file_name(image_path)
        
        return tf.train.Example(features = tf.train.Features(feature = {**image_feature, **label_feature}))
    
    def _convert_image(self, image_path, encode_image_shape=False):
        '''
        Converts an image and returns a dictionary of `tf.train.Feature` objects,
        which is the input to `tf.train.Features`.
        '''
        
        file_name = os.path.basename(image_path)
        
        # Read a byte representation of the image.
        with tf.gfile.GFile(image_path, 'rb') as fid:
            image = fid.read()
            
        if encode_image_shape:
            image_shape = mpimg.imread(image_path).shape
            if len(image_shape) == 2:
                image_shape += (1,)
            return {
                'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
                'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [file_name.encode('utf-8')])),
                'height': tf.train.Feature(int64_list = tf.train.Int64List(value = [image_shape[0]])),
                'width': tf.train.Feature(int64_list = tf.train.Int64List(value = [image_shape[1]])),
                'channels': tf.train.Feature(int64_list = tf.train.Int64List(value = [image_shape[2]]))
            }
               
        return {
            'image': tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
            'filename': tf.train.Feature(bytes_list = tf.train.BytesList(value = [file_name.encode('utf-8')]))
        }
            
    def _convert_image_class_from_file_name(self, image_path):
        
        label = get_image_class_from_file_name(image_path)
        return {
            'label': tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))
        }

In [28]:
input_directory = '/pierluigiferrari/datasets/256_ObjectCategories/'
output_file = '/pierluigiferrari/datasets/Caltech256/Caltecth256'
extract_directory = '/pierluigiferrari/datasets/Extracted_TFRecords/'

In [29]:
tfrecord_generator = TFRecordGenerator()

tfrecord_generator.convert(input_directory, output_file, encode_image_shape=False)

Creating TFRecord file 1/31: 100%|██████████| 1000/1000 [00:00<00:00, 5737.74it/s]
Creating TFRecord file 2/31: 100%|██████████| 1000/1000 [00:00<00:00, 5474.69it/s]
Creating TFRecord file 3/31: 100%|██████████| 1000/1000 [00:00<00:00, 5699.81it/s]
Creating TFRecord file 4/31: 100%|██████████| 1000/1000 [00:00<00:00, 5364.11it/s]
Creating TFRecord file 5/31: 100%|██████████| 1000/1000 [00:00<00:00, 5558.59it/s]
Creating TFRecord file 6/31: 100%|██████████| 1000/1000 [00:00<00:00, 5784.88it/s]
Creating TFRecord file 7/31: 100%|██████████| 1000/1000 [00:00<00:00, 5258.14it/s]
Creating TFRecord file 8/31: 100%|██████████| 1000/1000 [00:00<00:00, 5531.70it/s]
Creating TFRecord file 9/31: 100%|██████████| 1000/1000 [00:00<00:00, 5504.94it/s]
Creating TFRecord file 10/31: 100%|██████████| 1000/1000 [00:00<00:00, 5703.04it/s]
Creating TFRecord file 11/31: 100%|██████████| 1000/1000 [00:00<00:00, 5737.04it/s]
Creating TFRecord file 12/31: 100%|██████████| 1000/1000 [00:00<00:00, 5604.70it/s]
C

In [8]:
def get_subdirectories(directory, include_top=True):
    '''
    Return a list of all subdirectories of `directory`.
    '''
    
    subdirectories = []
    
    if include_top:
        subdirectories.append(directory)
    
    for dirpath, dirnames, filenames in os.walk(top=directory, topdown=True):
        
        subdirs = [os.path.join(dirpath, dirname) for dirname in dirnames]
        subdirectories += subdirs
        
    return subdirectories

In [9]:
def get_image_paths(directories, extensions=['jpg','jpeg','png']):
    '''
    Return a list of all image paths in `directories`.
    
    Arguments:
        directories (list): A list of directory paths to iterate over.
        extensions (list, optional): An optional list of strings that
            define all acceptable file extensions. If `None`, any file
            extension is acceptable.
    '''
    
    image_paths = []
    
    if extensions is None:
    
        for directory in directories:

            image_paths += glob.glob(os.path.join(directory, '*'))
            
    else:
            
        for directory in directories:
            
            for extension in extensions:

                image_paths += glob.glob(os.path.join(directory, '*.'+extension))
                
    return image_paths

In [10]:
def get_image_class_from_file_name(image_path, separator='_'):
    '''
    Return the class ID of an image (i.e. an integer) based on the beginning of the
    image name string
    '''
    
    return int(os.path.basename(image_path).split(separator)[0])

In [103]:
class TFRecordExtractor:
    
    def __init__(self, tfrecord_file):
        self.tfrecord_file = os.path.abspath(tfrecord_file)

    def _extract_fn(self, tfrecord):
        
        # Extract features using the keys set during creation
        features = {
            'filename': tf.FixedLenFeature([], tf.string),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64),
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64)
        }

        # Extract the data record
        sample = tf.parse_single_example(tfrecord, features)

        image = tf.image.decode_image(sample['image'])        
        img_shape = tf.stack([sample['height'], sample['width'], sample['channels']])
        label = sample['label']
        filename = sample['filename']
        return [image, label, filename, img_shape]

    def extract_image(self):
        # Create folder to store extracted images
        folder_path = extract_directory
        shutil.rmtree(folder_path, ignore_errors = True)
        os.mkdir(folder_path)

        # Pipeline of dataset and iterator
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        dataset = dataset.map(self._extract_fn)
        iterator = dataset.make_one_shot_iterator()
        next_image_data = iterator.get_next()

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            try:
                # Keep extracting data till TFRecord is exhausted
                while True:
                    image_data = sess.run(next_image_data)

                    # Check if image shape is same after decoding
                    if not np.array_equal(image_data[0].shape, image_data[3]):
                        print('Image {} not decoded properly'.format(image_data[2]))
                        continue

                    save_path = os.path.abspath(os.path.join(folder_path, image_data[2].decode('utf-8')))
                    image = np.squeeze(image_data[0])
                    mpimg.imsave(save_path, image)
                    print('Save path = ', save_path, ', Label = ', image_data[1])
            except:
                pass

In [105]:
tfrecord_file = '/pierluigiferrari/datasets/Caltech256/Caltech256_0030.tfrecord'

tfrecord_extractor = TFRecordExtractor(tfrecord_file)

tfrecord_extractor.extract_image()

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/152_0072.jpg , Label =  152
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/257_0244.jpg , Label =  257
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/164_0015.jpg , Label =  164
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/193_0081.jpg , Label =  193
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/182_0088.jpg , Label =  182
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/079_0038.jpg , Label =  79
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/255_0102.jpg , Label =  255
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/239_0033.jpg , Label =  239
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/218_0010.jpg , Label =  218
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/142_0047.jpg , Label =  142
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/126_0132.jpg , Label =  126
Save path =  /pierluigiferrari/da

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/059_0024.jpg , Label =  59
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/221_0037.jpg , Label =  221
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/198_0007.jpg , Label =  198
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/199_0089.jpg , Label =  199
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/202_0058.jpg , Label =  202
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/162_0029.jpg , Label =  162
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/003_0087.jpg , Label =  3
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/125_0045.jpg , Label =  125
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/251_0506.jpg , Label =  251
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/020_0067.jpg , Label =  20
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/214_0061.jpg , Label =  214
Save path =  /pierluigiferrari/datas

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/137_0084.jpg , Label =  137
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/193_0069.jpg , Label =  193
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/129_0033.jpg , Label =  129
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/129_0140.jpg , Label =  129
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/046_0041.jpg , Label =  46
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/058_0064.jpg , Label =  58
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/120_0129.jpg , Label =  120
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/102_0039.jpg , Label =  102
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/076_0079.jpg , Label =  76
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/145_0401.jpg , Label =  145
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/127_0086.jpg , Label =  127
Save path =  /pierluigiferrari/data

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/247_0051.jpg , Label =  247
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/216_0097.jpg , Label =  216
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/196_0012.jpg , Label =  196
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/254_0049.jpg , Label =  254
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/246_0088.jpg , Label =  246
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/253_0203.jpg , Label =  253
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/052_0037.jpg , Label =  52
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/083_0031.jpg , Label =  83
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/191_0008.jpg , Label =  191
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/161_0009.jpg , Label =  161
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/256_0077.jpg , Label =  256
Save path =  /pierluigiferrari/dat

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/056_0023.jpg , Label =  56
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/185_0038.jpg , Label =  185
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/251_0463.jpg , Label =  251
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/023_0015.jpg , Label =  23
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/145_0237.jpg , Label =  145
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/212_0029.jpg , Label =  212
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/139_0002.jpg , Label =  139
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/073_0048.jpg , Label =  73
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/124_0031.jpg , Label =  124
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/227_0077.jpg , Label =  227
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/221_0072.jpg , Label =  221
Save path =  /pierluigiferrari/data

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/009_0073.jpg , Label =  9
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/105_0070.jpg , Label =  105
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/003_0100.jpg , Label =  3
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/251_0314.jpg , Label =  251
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/079_0069.jpg , Label =  79
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/156_0080.jpg , Label =  156
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/059_0060.jpg , Label =  59
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/028_0032.jpg , Label =  28
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/143_0111.jpg , Label =  143
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/145_0682.jpg , Label =  145
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/131_0087.jpg , Label =  131
Save path =  /pierluigiferrari/datasets

Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/051_0074.jpg , Label =  51
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/255_0071.jpg , Label =  255
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/012_0112.jpg , Label =  12
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/078_0019.jpg , Label =  78
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/248_0045.jpg , Label =  248
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/253_0124.jpg , Label =  253
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/026_0085.jpg , Label =  26
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/165_0039.jpg , Label =  165
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/185_0023.jpg , Label =  185
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/079_0010.jpg , Label =  79
Save path =  /pierluigiferrari/datasets/Extracted_TFRecords/105_0109.jpg , Label =  105


In [32]:
class TFRecordDataset:
    
    def __init__(self,
                 tfrecord_paths,
                 encode_image_shape=False):
        
        self.tfrecord_paths = [os.path.abspath(tfrecord_path) for tfrecord_path in tfrecord_paths]
        self.encode_image_shape = encode_image_shape
        
    def _parse_tfrecord(self, single_example_proto):
        
        # Get the feature definition that was set during creation of the TFRecord.
        feature_definition = self._get_feature_definition()

        # Parse a single sample from the TFRecord.
        sample = tf.parse_single_example(single_example_proto, feature_definition)

        # Decode the image.
        image = tf.image.decode_image(sample['image'])
        filename = sample['filename']
        label = sample['label']
        if self.encode_image_shape:
            img_shape = tf.stack([sample['height'], sample['width'], sample['channels']])
            return image, label, filename, img_shape
        return image, label, filename
    
    def _get_feature_definition(self):
        
        if self.encode_image_shape:
            return {'image': tf.FixedLenFeature([], tf.string),
                    'filename': tf.FixedLenFeature([], tf.string),
                    'height': tf.FixedLenFeature([], tf.int64),
                    'width': tf.FixedLenFeature([], tf.int64),
                    'channels': tf.FixedLenFeature([], tf.int64),
                    'label': tf.FixedLenFeature([], tf.int64)}
        else:
            return {'image': tf.FixedLenFeature([], tf.string),
                    'filename': tf.FixedLenFeature([], tf.string),
                    'label': tf.FixedLenFeature([], tf.int64)}
        
    def create_dataset(self):
        
        dataset = tf.data.TFRecordDataset(self.tfrecord_paths, num_parallel_reads=None)
        dataset = dataset.map(self._parse_tfrecord, num_parallel_calls=None)