In [1]:
import numpy as np
import tensorflow as tf
import pickle
import os

# 如何将 CIFAR10 数据集写成 TFRecords 格式

CIFAR10 数据集可以从[这里](https://www.cs.toronto.edu/~kriz/cifar.html)下载：选择 Python 版本，下载并解压后，将其中的 `data_batch_1` ~ `data_batch_5` 和 `test_batch` 文件放到 `cifar10-dataset` 文件夹。

## 首先将数据集读取成 numpy 数组

In [88]:
# Directory (relative to the notebook) that holds the CIFAR-10 batch files.
dataset_dir = 'cifar10-dataset'

In [89]:
# Training data is split across five batch files: data_batch_1 .. data_batch_5.
train_filenames = ['data_batch_{}'.format(batch_no) for batch_no in range(1, 6)]
# The test split is a single batch file.
test_filenames = ['test_batch']

In [90]:
def unpickle(filename):
    '''Load one pickled batch file and return the decoded object.

    The file is opened in binary mode and unpickled with
    ``encoding='latin1'``.

    NOTE(review): unpickling can execute arbitrary code, so only call this
    on batch files obtained from a trusted source.
    '''
    with open(filename, 'rb') as batch_file:
        return pickle.load(batch_file, encoding='latin1')

In [91]:
# Unpickle each batch file once and take both fields from the same dict;
# calling unpickle() separately for 'data' and 'labels' reads and decodes
# the whole file twice.
first_train_batch = unpickle(os.path.join(dataset_dir, train_filenames[0]))
train_images = first_train_batch['data']
train_labels = first_train_batch['labels']

test_batch = unpickle(os.path.join(dataset_dir, test_filenames[0]))
test_images = test_batch['data']
test_labels = test_batch['labels']

In [43]:
# Rebuild the full training set from every batch file in one pass.
# Starting from the files (rather than appending to the existing
# train_images / train_labels) keeps this cell idempotent: re-running it
# cannot append the same batches a second time. A single concatenate also
# avoids re-copying the growing arrays on every loop iteration.
train_batches = [unpickle(os.path.join(dataset_dir, name)) for name in train_filenames]
train_images = np.concatenate([batch['data'] for batch in train_batches], axis=0)
train_labels = np.concatenate([batch['labels'] for batch in train_batches], axis=0)

In [48]:
# Number of training examples = number of rows in train_images.
num_examples = len(train_images)
num_examples

50000

In [92]:
# Number of test examples = number of rows in test_images.
test_num_examples = len(test_images)
test_num_examples

10000

In [49]:
# Peek at the first training example: one flat row of pixel values.
train_images[0]

array([ 59,  43,  50, ..., 140,  84,  72], dtype=uint8)

In [44]:
# Sanity check: (number of examples, values per image).
train_images.shape

(50000, 3072)

In [45]:
# Sanity check: there should be exactly one label per row of train_images.
train_labels.shape

(50000,)

In [46]:
# Element type of the pixel data; the reader must decode with the same type.
train_images.dtype

dtype('uint8')

## 开始写入 TFRecords 文件

In [47]:
# Output path of the training TFRecords file.
filename = 'train.tfrecords'

In [93]:
# Output path of the TFRecords file written from the test split.
test_tfr_filename = 'eval.tfrecords'

In [94]:
def _int64_feature(value):
    '''Wrap a single integer in a tf.train.Feature holding an Int64List.'''
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(value):
    '''Wrap a single bytes value in a tf.train.Feature holding a BytesList.'''
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

In [51]:
# Training file: serialize every training image/label pair as one
# tf.train.Example record.
with tf.python_io.TFRecordWriter(filename) as writer:
    for index in range(num_examples):
        # tobytes() replaces the deprecated ndarray.tostring() alias; the
        # bytes produced are identical (the raw buffer of one image row).
        image_raw = train_images[index].tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            # int() turns the label into a plain Python int for Int64List.
            'label': _int64_feature(int(train_labels[index])),
            'image_raw': _bytes_feature(image_raw)
        }))
        writer.write(example.SerializeToString())

In [95]:
# Evaluation file: serialize every test image/label pair as one
# tf.train.Example record.
with tf.python_io.TFRecordWriter(test_tfr_filename) as writer:
    for index in range(test_num_examples):
        # tobytes() replaces the deprecated ndarray.tostring() alias; the
        # bytes produced are identical (the raw buffer of one image row).
        image_raw = test_images[index].tobytes()
        example = tf.train.Example(features=tf.train.Features(feature={
            # int() turns the label into a plain Python int for Int64List.
            'label': _int64_feature(int(test_labels[index])),
            'image_raw': _bytes_feature(image_raw)
        }))
        writer.write(example.SerializeToString())

## 读取 TFRecords 文件

In [47]:
# TFRecords file(s) to read back; this is the training file written earlier.
filenames = ['train.tfrecords']

In [48]:
# Session used by the sess.run() calls in the cells below.
sess = tf.InteractiveSession()

In [49]:
# Dataset whose elements are the serialized records stored in `filenames`.
dataset = tf.data.TFRecordDataset(filenames)

In [50]:
def parser(record):
    '''Decode one serialized tf.train.Example into an (image, label) pair.

    The image is returned as a flat float32 tensor built from the decoded
    bytes; the label is returned as an int32 scalar.
    '''
    feature_spec = {
        'image_raw': tf.FixedLenFeature((), tf.string),
        'label': tf.FixedLenFeature((), tf.int64)
    }
    fields = tf.parse_single_example(record, feature_spec)
    # Raw bytes -> uint8 values -> float32. The image is left flat; no
    # reshape is applied here.
    pixels = tf.cast(tf.decode_raw(fields['image_raw'], tf.uint8), tf.float32)
    return pixels, tf.cast(fields['label'], tf.int32)

In [51]:
# Apply `parser` to every serialized record.
# NOTE(review): `dataset` is reassigned from itself, so this cell is not
# idempotent; re-running it maps `parser` over already-parsed elements.
dataset = dataset.map(parser)

In [52]:
# Group consecutive elements into batches of 5.
# NOTE(review): `dataset` is reassigned from itself, so re-running this cell
# batches the already-batched dataset again.
dataset = dataset.batch(5)

In [53]:
# repeat() without a count cycles through the data indefinitely.
dataset = dataset.repeat()

In [54]:
# One-shot iterator over `dataset`.
iterator = dataset.make_one_shot_iterator()

In [55]:
# Element dtypes the iterator yields, in (image, label) order.
iterator.output_types

(tf.float32, tf.int32)

In [56]:
# Static shapes of the yielded elements; dimensions shown as None are not
# known at graph-construction time.
iterator.output_shapes

(TensorShape([Dimension(None), Dimension(None)]),
 TensorShape([Dimension(None)]))

In [57]:
# Symbolic tensors for the next (image, label) batch; values are only
# produced when these tensors are evaluated with sess.run().
feature, label = iterator.get_next()

In [58]:
# Display the symbolic tensor itself (name, static shape, dtype), not its values.
feature

<tf.Tensor 'IteratorGetNext_3:0' shape=(?, ?) dtype=float32>

In [59]:
# NOTE(review): each sess.run() on a get_next() tensor pulls a fresh batch,
# so this call consumes one batch and returns only its labels; the images
# of that batch are not returned by this call.
sess.run(label)

array([6, 9, 9, 4, 1])

In [60]:
# Evaluate features and labels in a single run() call so both come from the
# same batch; separate run() calls would each advance the iterator and
# return images and labels that do not belong together.
image, image_labels = sess.run([feature, label])
image.shape

(5, 3072)

In [61]:
# First image of the fetched batch, as a flat float32 vector.
image[0]

array([ 159.,  150.,  153., ...,   14.,   17.,   19.], dtype=float32)

In [46]:
# `image` is already float32 (parser casts to tf.float32), so this cast
# leaves the values unchanged; the cell only displays the fetched batch.
image.astype('float32')

array([[ 159.,  150.,  153., ...,   14.,   17.,   19.],
       [ 164.,  105.,  118., ...,   29.,   26.,   44.],
       [  28.,   30.,   33., ...,  100.,   99.,   96.],
       [ 134.,  131.,  128., ...,  136.,  137.,  138.],
       [ 125.,  110.,  102., ...,   82.,   84.,   86.]], dtype=float32)