Load the Fashion MNIST dataset (introduced in Chapter 10); split it into a training set, a validation set, and a test set; shuffle the training set; and save each dataset to multiple TFRecord files. Each record should be a serialized Example protobuf with two features: the serialized image (use tf.io.serialize_tensor() to serialize each image), and the label (For large images, you could use tf.io.encode_jpeg() instead. This would save a lot of space, but it would lose a bit of image quality) Then use tf.data to create an efficient dataset for each set. Finally, use a Keras model to train these datasets, including a preprocessing layer to standardize each input feature. Try to make the input pipeline as efficient as possible, using TensorBoard to visualize profiling data.

In [1]:
import tensorflow as tf

# Splitting to files

In [12]:
(x_train_full, y_train_full), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_valid = x_train_full[:-5000], x_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]


def to_protobuf(image, label):
    image_bytes = tf.io.serialize_tensor(image).numpy()

    image_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
    label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))

    return tf.train.Example(features=tf.train.Features(
        feature={
            'image': image_feature,
            'label': label_feature
        }
    ))

def dump_batch(batch, subfolder, id):
    file_name = f'data/mnist/{subfolder}/{str(id).zfill(4)}.tfrecord'
    with tf.io.TFRecordWriter(file_name) as f:
        for index, item in enumerate(batch['image']):
            protobuf = to_protobuf(item, batch['label'][index])
            bytes = protobuf.SerializeToString()
            f.write(bytes)

def preprocess(x, y, subfolder):
    ds = tf.data.Dataset.from_tensor_slices({
        'image': x,
        'label': y
    }).shuffle(buffer_size=10000).batch(100)
    for id, batch in enumerate(ds):
        dump_batch(batch, subfolder, id)

# preprocess(x_test, y_test, 'test')
# preprocess(x_train, y_train, 'train')
# preprocess(x_valid, y_valid, 'valid')

# Reading files

In [10]:
feature_descriptions = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64)
}

def parse(example_proto):
    example = tf.io.parse_single_example(example_proto, feature_descriptions)
    image = tf.io.parse_tensor(example['image'], out_type=tf.uint8)
    image = tf.reshape(image, [28, 28])
    label = example['label']
    return image, label

In [14]:
batch_size = 32

training_files = tf.data.Dataset.list_files('data/mnist/train/*.tfrecord')
training = tf.data.TFRecordDataset(training_files, num_parallel_reads=5).shuffle(10000).map(parse, num_parallel_calls=5).batch(batch_size).prefetch(1)

valid_files = tf.data.Dataset.list_files('data/mnist/valid/*.tfrecord')
valid = tf.data.TFRecordDataset(valid_files, num_parallel_reads=5).map(parse, num_parallel_calls=5).batch(batch_size).prefetch(1)

test_files = tf.data.Dataset.list_files('data/mnist/test/*.tfrecord')
test = tf.data.TFRecordDataset(test_files, num_parallel_reads=5).map(parse, num_parallel_calls=5).batch(batch_size).prefetch(1)

In [26]:
for x in training.take(1):
    print(x[0].shape)
    print(x[1].shape)

(32, 28, 28)
(32,)


# The model

In [29]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Rescaling(scale=1/255),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(30, activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dense(30, activation='relu', kernel_initializer='he_normal'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

model.compile(
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    optimizer=tf.keras.optimizers.legacy.Nadam(),
    metrics=[tf.keras.metrics.sparse_categorical_accuracy],
)

tensorboard_cb = tf.keras.callbacks.TensorBoard('data/mnist/tensorboard')
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint('data/mnist/checkpoints', save_best_only=True)

hist = model.fit(
    training,
    epochs=10,
    validation_data=valid,
    callbacks=[tensorboard_cb, checkpoint_cb]
)

Epoch 1/10
   1719/Unknown - 22s 11ms/step - loss: 0.3595 - sparse_categorical_accuracy: 0.8955INFO:tensorflow:Assets written to: data/mnist/checkpoints/assets
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [30]:
model.evaluate(test)



[0.1292203962802887, 0.9640000462532043]