In [None]:
import os
import tensorflow as tf
import tarfile
import urllib.request

In [None]:
# Base directory of the notebook; every other path is derived from it.
_root = '.'
# Directory that holds the downloaded archive and the generated TFRecord files.
_data_root = '/'.join((_root, 'data'))

In [None]:
# Download the CIFAR archive once and unpack it into the data directory.
archive_file = os.path.join(_data_root, 'cifar.tgz')

# urlretrieve does not create missing parent directories, so make sure the
# data directory exists before anything is written into it.
os.makedirs(_data_root, exist_ok=True)

if os.path.isfile(archive_file):
    print(archive_file + ' file already exists, skipping...')
else:
    # Download under a temporary name and rename on success, so that an
    # interrupted download is never mistaken for a complete archive by the
    # isfile() guard above on the next run.
    partial_file = archive_file + '.part'
    urllib.request.urlretrieve('http://pjreddie.com/media/files/cifar.tgz', partial_file)
    os.replace(partial_file, archive_file)
    # NOTE(review): the archive is fetched over plain HTTP and extracted
    # without member checks; a tampered archive could write outside
    # _data_root. Consider validating member paths before extracting.
    # The context manager closes the tar handle even if extraction fails.
    with tarfile.open(archive_file, "r:gz") as tar:
        tar.extractall(_data_root)

In [None]:
# Convert the extracted CIFAR images into one TFRecord file per split.
# Each record holds an int64 'label' feature and a bytes 'image_raw' feature
# containing the JPEG-encoded image.

# Side length (in pixels) every image is resized to before encoding.
IMG_SIZE = 32

# Class name (parsed from the image file name) -> integer label id.
labels = {"airplane": 0,
          "automobile": 1,
          "bird": 2,
          "cat": 3,
          "deer": 4,
          "dog": 5,
          "frog": 6,
          "horse": 7,
          "ship": 8,
          "truck": 9}

for annotation in ('train', 'test'):
    output_file = os.path.join(_data_root, annotation + '.tfrecord')
    if os.path.isfile(output_file):
        print(output_file + ' file already exists, skipping...')
        continue

    folder = os.path.join(_data_root, 'cifar', annotation)
    # Write under a temporary name and rename only after every record has been
    # written; otherwise a failure mid-way would leave a truncated file that
    # the isfile() guard above would accept on the next run.
    partial_file = output_file + '.part'
    with tf.io.TFRecordWriter(partial_file) as writer:
        # Sort the listing so the record order is reproducible across runs
        # (os.listdir returns entries in arbitrary order).
        for file_name in sorted(os.listdir(folder)):
            # The label is the second '_'-separated token of the file name
            # without its extension, e.g. '<index>_<label>.png'.
            label = os.path.splitext(file_name)[0].split("_")[1]
            label_id = labels[label]
            # Read through a context manager so the file handle is closed.
            with open(os.path.join(folder, file_name), 'rb') as image_file:
                image_string = image_file.read()
            image = tf.image.decode_image(image_string, dtype=tf.dtypes.uint8)
            # tf.image.resize replaces tf.image.resize_images, which is not
            # available in TF 2.x. The result is cast back to uint8 because
            # encode_jpeg requires uint8 input.
            image_resized = tf.cast(
                tf.image.resize(image, size=[IMG_SIZE, IMG_SIZE],
                                method=tf.image.ResizeMethod.AREA),
                tf.uint8)
            image_string = tf.image.encode_jpeg(image_resized)
            feature = {
                'label': tf.train.Feature(int64_list=tf.train.Int64List(
                    value=[label_id])),
                # .numpy() needs eager execution to obtain the encoded bytes.
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(
                    value=[image_string.numpy()])),
            }
            tf_example = tf.train.Example(features=tf.train.Features(feature=feature))
            writer.write(tf_example.SerializeToString())
    os.replace(partial_file, output_file)