<a href="https://colab.research.google.com/github/wakamezake/Notebooks/blob/master/cifar10_to_TFRecords.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### References
#### Kaggle
- https://www.kaggle.com/gauravchopracg/understanding-tfrecord-format
- https://www.kaggle.com/cdeotte/how-to-create-tfrecords

#### GitHub
- https://github.com/zbchern/ResNet-at-Cifar10/blob/master/datasets/generate_cifar10_tfrecords.py

#### Google
- https://codelabs.developers.google.com/codelabs/keras-flowers-data#3
- https://www.tensorflow.org/guide/data_performance
- https://www.tensorflow.org/tutorials/load_data/tfrecord

In [1]:
import tensorflow as tf
tf.__version__

'2.3.0'

In [38]:
# Short alias for the CIFAR-10 dataset module that ships with Keras.
cifar10 = tf.keras.datasets.cifar10

In [39]:
# Unpack the train and test splits, each as an (images, labels) pair.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

In [40]:
# Inspect the element types; they determine how the raw bytes must be decoded later.
x_train.dtype, y_train.dtype

(dtype('uint8'), dtype('uint8'))

In [43]:
def _bytes_feature(value):
  """Wrap a single string / bytes value in a ``tf.train.Feature``.

  Args:
    value: Raw ``bytes``/``str``, or an eager tensor holding one.

  Returns:
    A ``tf.train.Feature`` whose ``bytes_list`` contains exactly ``value``.
  """
  eager_tensor_type = type(tf.constant(0))
  # BytesList cannot unpack an EagerTensor, so pull out the underlying value first.
  payload = value.numpy() if isinstance(value, eager_tensor_type) else value
  packed = tf.train.BytesList(value=[payload])
  return tf.train.Feature(bytes_list=packed)

def _float_feature(value):
  """Wrap a single float / double in a ``tf.train.Feature``.

  Args:
    value: The scalar to store.

  Returns:
    A ``tf.train.Feature`` whose ``float_list`` contains exactly ``value``.
  """
  packed = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=packed)

def _int64_feature(value):
  """Wrap a single bool / enum / int / uint in a ``tf.train.Feature``.

  Args:
    value: The integer-like scalar to store.

  Returns:
    A ``tf.train.Feature`` whose ``int64_list`` contains exactly ``value``.
  """
  packed = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=packed)

def ndarray2tfrecord(path, data):
    """Serialize paired image and label arrays into a single TFRecord file.

    Each record is a ``tf.train.Example`` with two features:
    ``'image'`` (the raw bytes of one image array) and ``'label'``
    (one int64 value).

    Args:
        path: Destination path of the TFRecord file to create.
        data: ``(images, labels)`` pair. Both are indexed by position;
            every ``images[i]`` must provide ``tobytes()`` and every
            ``labels[i]`` must hold exactly one integer.
    """
    images, labels = data
    num_entries = len(labels)
    with tf.io.TFRecordWriter(path) as record_writer:
        for idx in range(num_entries):
            label = labels[idx]
            # Unwrap NumPy scalars / one-element arrays to a plain Python int
            # so Int64List does not have to coerce an array object.
            if hasattr(label, 'item'):
                label = label.item()
            example = tf.train.Example(features=tf.train.Features(
                        feature={
                            # tobytes must be *called*: the bound method itself
                            # is not bytes and BytesList rejects it.
                            'image': _bytes_feature(images[idx].tobytes()),
                            'label': _int64_feature(label)
                        }))
            record_writer.write(example.SerializeToString())

In [44]:
# Sanity check: wrap the raw bytes of the first training image in a BytesList feature.
# https://numpy.org/doc/stable/reference/generated/numpy.ndarray.tobytes.html
tf.train.Feature(bytes_list=tf.train.BytesList(value=[x_train[0].tobytes()]))

bytes_list {
  value: ";>?+.-20+D6*bI4w[?\213kK\221nP\225uY\225x]\203gM}cL\216s[\220pV\211iO\201aG\211jO\206jL|a@\213qN\213pK\205iE\210iJ\213lM\230xY\243\203d\250\210l\237\201f\236\202h\236\204l\230}f\224|g\020\024\024\000\000\000\022\010\0003\033\010X3\025xR+\200Y-\177V,~W2tO,jF%eC#iF$qJ#mF!pH%wO,mG!iE\033}Y.\177\\.zU\'\203Y/|R)yO%\203Y0\204[5\205^:\205`<{X7wS2zW9\031\030\025\020\007\0001\033\010S2\027nH)\201\\6\202]7yR/qM+pN,pO.jK-iI&\200\\0|W/\202\\8\177Y8zU3sO+xU/\202_6\203`7\213f>\177Z3~Y1\177Y2\202\\5\216iD\202^:vT2xT2mI*!\031\021&\024\004W6\031j?\034sF!uJ#rH%i>!kD!yT-}Z5mK(qM&\222i:\205[/\177T-vL(uL)\177W4zQ+\204\\3\211c:\210c9\203]4|V,\202[2\204Z1\207]3\202Z2}W2yU0^>#2 \025; \013fA\"\177O\'|M$yM$xN(rJ\'kH\"}X1\201Y3jD\037lG!|S*yN\'lD\035bA\027nJ%uP1xP)\206]2\214jB\203_:\215bB\207\\3\177T-yO)wO(gC W9\033K/\027C*\031G0\035T5\030nI%\201R&\210X-\203T*\201T+wM%lF!zR,{Q\'iA\031kH\037oM\037lJ\"bA\033^>\025a? S8&X:$fD*aE.X6$vJH\214`O\210a@xP-kD\"X6\030C\'\017#\n\000 \r\004aE(oK${U+\202

In [45]:
# Write the training split to train.tfrec (relative to the working directory).
ndarray2tfrecord('train.tfrec', (x_train, y_train))

In [46]:
# Write the test split to test.tfrec (relative to the working directory).
ndarray2tfrecord('test.tfrec', (x_test, y_test))

In [48]:
# Re-open the files written above; elements are still serialized, unparsed records.
train_dataset = tf.data.TFRecordDataset(['train.tfrec'])
test_dataset = tf.data.TFRecordDataset(['test.tfrec'])

In [50]:
# Peek at one raw element. Despite the loop variable's name, the dataset is not
# batched here, so this is a single serialized record (a scalar string tensor).
for batch in train_dataset.take(1):
    print(repr(batch))

<tf.Tensor: shape=(), dtype=string, numpy=b'\n\xa3\x18\n\x0e\n\x05label\x12\x05\x1a\x03\n\x01\x06\n\x90\x18\n\x05image\x12\x86\x18\n\x83\x18\n\x80\x18;>?+.-20+D6*bI4w[?\x8bkK\x91nP\x95uY\x95x]\x83gM}cL\x8es[\x90pV\x89iO\x81aG\x89jO\x86jL|a@\x8bqN\x8bpK\x85iE\x88iJ\x8blM\x98xY\xa3\x83d\xa8\x88l\x9f\x81f\x9e\x82h\x9e\x84l\x98}f\x94|g\x10\x14\x14\x00\x00\x00\x12\x08\x003\x1b\x08X3\x15xR+\x80Y-\x7fV,~W2tO,jF%eC#iF$qJ#mF!pH%wO,mG!iE\x1b}Y.\x7f\\.zU\'\x83Y/|R)yO%\x83Y0\x84[5\x85^:\x85`<{X7wS2zW9\x19\x18\x15\x10\x07\x001\x1b\x08S2\x17nH)\x81\\6\x82]7yR/qM+pN,pO.jK-iI&\x80\\0|W/\x82\\8\x7fY8zU3sO+xU/\x82_6\x83`7\x8bf>\x7fZ3~Y1\x7fY2\x82\\5\x8eiD\x82^:vT2xT2mI*!\x19\x11&\x14\x04W6\x19j?\x1csF!uJ#rH%i>!kD!yT-}Z5mK(qM&\x92i:\x85[/\x7fT-vL(uL)\x7fW4zQ+\x84\\3\x89c:\x88c9\x83]4|V,\x82[2\x84Z1\x87]3\x82Z2}W2yU0^>#2 \x15; \x0bfA"\x7fO\'|M$yM$xN(rJ\'kH"}X1\x81Y3jD\x1flG!|S*yN\'lD\x1dbA\x17nJ%uP1xP)\x86]2\x8cjB\x83_:\x8dbB\x87\\3\x7fT-yO)wO(gC W9\x1bK/\x17C*\x19G0\x1dT5\x18nI%\x81R&\x88X-\x83T*\x81T+wM

In [77]:
# Feature spec handed to tf.io.parse_single_example. The keys and dtypes mirror
# what ndarray2tfrecord writes: raw image bytes plus one int64 label.
schema = {
    'image': tf.io.FixedLenFeature([], tf.string),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

WIDTH, HEIGHT = 32, 32
# TensorFlow image tensors are ordered (height, width, channels), and this list
# is used as the leading dimensions of a reshape, so height must come first.
# Identical to the previous [WIDTH, HEIGHT] while the images are square.
IMAGE_SIZE = [HEIGHT, WIDTH]

def _parse_func(_example):
    """Parse one serialized ``tf.train.Example`` using the module-level ``schema``.

    Args:
        _example: Scalar string tensor holding a serialized example.

    Returns:
        Dict mapping each key of ``schema`` to its parsed tensor.
    """
    return tf.io.parse_single_example(serialized=_example, features=schema)

def decode_image(image):
    """Turn a raw uint8 byte string into a float32 image tensor.

    Args:
        image: Scalar string tensor containing the raw pixel bytes.

    Returns:
        float32 tensor of shape ``[*IMAGE_SIZE, 3]`` with the byte values
        divided by 255.
    """
    target_shape = [*IMAGE_SIZE, 3]
    raw_pixels = tf.io.decode_raw(image, tf.uint8)
    scaled_pixels = tf.cast(raw_pixels, tf.float32) / 255.0
    return tf.reshape(scaled_pixels, target_shape)

def read_labeled_tfrecord(example):
    """Convert one serialized record into an ``(image, label)`` pair.

    Args:
        example: Scalar string tensor holding a serialized example.

    Returns:
        Tuple of the decoded image tensor and the parsed ``'label'`` tensor.
    """
    parsed = _parse_func(example)
    return decode_image(parsed['image']), parsed['label']

def load_dataset(filenames):
    """Build a dataset of decoded ``(image, label)`` pairs from TFRecord files.

    Args:
        filenames: TFRecord file path(s) accepted by ``tf.data.TFRecordDataset``.

    Returns:
        Dataset with ``read_labeled_tfrecord`` applied to every record.
    """
    return tf.data.TFRecordDataset(filenames).map(read_labeled_tfrecord)

def get_dataset(filenames, repeated=False, ordered=False, shuffle_buffer=2048):
    """Build the batched, prefetched input pipeline for training or evaluation.

    Reads the module-level ``BATCH_SIZE`` and ``AUTO`` at call time, so the
    cell defining them must have been run before this function is called.

    Args:
        filenames: TFRecord file path(s) to read.
        repeated: If True, repeat the dataset indefinitely.
        ordered: If False, shuffle the examples; if True, keep file order.
        shuffle_buffer: Size of the shuffle buffer used when ``ordered`` is
            False. Defaults to 2048, the previously hard-coded value.

    Returns:
        Dataset yielding batches of ``(image, label)``.
    """
    dataset = load_dataset(filenames)
    if not ordered:
        # Shuffle before repeating: shuffling an already-repeated stream mixes
        # examples across epoch boundaries, so an "epoch" may contain
        # duplicates and miss other examples.
        dataset = dataset.shuffle(shuffle_buffer)
    if repeated:
        dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

In [66]:
# Smoke-test the read path on the first record: parse it with the schema,
# decode the pixel bytes, then print the resulting image and label tensors.
for image_features in train_dataset.map(_parse_func).take(1):
    image = decode_image(image_features['image'])
    label = image_features['label']
    print(image)
    print(label)

tf.Tensor(
[[[0.23137255 0.24313726 0.24705882]
  [0.16862746 0.18039216 0.1764706 ]
  [0.19607843 0.1882353  0.16862746]
  ...
  [0.61960787 0.5176471  0.42352942]
  [0.59607846 0.49019608 0.4       ]
  [0.5803922  0.4862745  0.40392157]]

 [[0.0627451  0.07843138 0.07843138]
  [0.         0.         0.        ]
  [0.07058824 0.03137255 0.        ]
  ...
  [0.48235294 0.34509805 0.21568628]
  [0.46666667 0.3254902  0.19607843]
  [0.47843137 0.34117648 0.22352941]]

 [[0.09803922 0.09411765 0.08235294]
  [0.0627451  0.02745098 0.        ]
  [0.19215687 0.10588235 0.03137255]
  ...
  [0.4627451  0.32941177 0.19607843]
  [0.47058824 0.32941177 0.19607843]
  [0.42745098 0.28627452 0.16470589]]

 ...

 [[0.8156863  0.6666667  0.3764706 ]
  [0.7882353  0.6        0.13333334]
  [0.7764706  0.6313726  0.10196079]
  ...
  [0.627451   0.52156866 0.27450982]
  [0.21960784 0.12156863 0.02745098]
  [0.20784314 0.13333334 0.07843138]]

 [[0.7058824  0.54509807 0.3764706 ]
  [0.6784314  0.48235294 0

In [68]:
# Pipeline settings read by get_dataset at call time.
# AUTO is passed to prefetch(); BATCH_SIZE is passed to batch().
AUTO = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 32

In [73]:
# Collect the record files by glob pattern, relative to the working directory.
TRAINING_FILENAMES = tf.io.gfile.glob('train*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob('test*.tfrec')

In [78]:
# Training data is shuffled; validation data keeps file order.
# Neither dataset repeats (repeated is left at its default of False).
ds_train = get_dataset(TRAINING_FILENAMES, ordered=False)
ds_valid = get_dataset(TEST_FILENAMES, ordered=True)

In [79]:
from tensorflow.keras import datasets, layers, models

In [80]:
# Small CNN: three conv layers with two max-pool stages, followed by a dense head.
# The final Dense(10) has no activation, so the model outputs raw logits.
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(10),
])

In [83]:
# Number of passes over ds_train.
EPOCHS = 3
# from_logits=True because the model's last layer is Dense(10) with no activation.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
# Train on the TFRecord-backed pipeline and validate on the test records each epoch.
model.fit(ds_train, validation_data=ds_valid, epochs=EPOCHS)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7ff6e7630da0>