In [41]:
import tensorflow as tf 
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import StandardScaler
import tensorflow_datasets as tfds
import os
import contextlib

In [42]:
def convert_to_example_protobuf(image, label):
    """Pack one (image, label) pair into a ``tf.train.Example`` message.

    Args:
        image: Value passed to ``tf.io.serialize_tensor``; the serialized
            bytes are stored under the ``'image'`` feature key.
        label: Object exposing ``.numpy()``; the returned value is stored as
            the single entry of an int64 list under the ``'label'`` key.

    Returns:
        A ``tf.train.Example`` holding the two features above.
    """
    # Serialize the image first, then read the label (same order as before).
    image_bytes = tf.io.serialize_tensor(image).numpy()
    label_value = label.numpy()

    image_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes]))
    label_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[label_value]))

    feature_map = {'image': image_feature, 'label': label_feature}
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

def write_multiple_tfrecords(folder, prefix, dataset, n_shards):
    """Write ``dataset`` to ``n_shards`` TFRecord files in round-robin order.

    The element at position ``k`` of ``dataset`` is converted with
    ``convert_to_example_protobuf`` and written to shard ``k % n_shards``.
    Shard files are named ``{prefix}-{i:03d}-of-{n_shards:03d}`` inside
    ``folder``; the folder is created if it does not exist yet.

    Args:
        folder: Output directory (created with ``exist_ok=True``).
        prefix: File-name prefix shared by all shards.
        dataset: Iterable yielding ``(image, label)`` pairs.
        n_shards: Number of shard files to write; must be at least 1.

    Returns:
        The list of shard paths, in shard-index order.

    Raises:
        ValueError: If ``n_shards`` is smaller than 1.
    """
    # Fail before touching the filesystem: with zero or negative shards the
    # loop below would otherwise die with ZeroDivisionError / IndexError (or
    # silently write nothing when the dataset is empty).
    if n_shards < 1:
        raise ValueError(f'n_shards must be a positive integer, got {n_shards!r}')

    os.makedirs(folder, exist_ok=True)
    paths = [f'{folder}/{prefix}-{i:03d}-of-{n_shards:03d}' for i in range(n_shards)]
    # ExitStack closes every writer on exit, including when an error occurs
    # part-way through the loop.
    with contextlib.ExitStack() as stack:
        writers = [stack.enter_context(tf.io.TFRecordWriter(path)) for path in paths]
        for index, (image, label) in enumerate(dataset):
            example = convert_to_example_protobuf(image, label)
            writers[index % n_shards].write(example.SerializeToString())
    return paths

In [43]:
# Settings shared by this cell: one seed for both the split and the shuffle,
# and one output folder / shard count for all three splits.
SEED = 42
TFRECORD_DIR = 'fashion_mnist_tfrecords'
N_SHARDS = 10

(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()
# Hold out 20% of the original training data as a validation split.
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images, train_labels, test_size=0.2, random_state=SEED)

# Only the training split is shuffled. Seeding the shuffle makes the record
# order written to the shards reproducible across kernel restarts.
training_set = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(
    buffer_size=len(train_images), seed=SEED)
validation_set = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
test_set = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

train_filepaths = write_multiple_tfrecords(TFRECORD_DIR, 'train', training_set, N_SHARDS)
validation_filepaths = write_multiple_tfrecords(TFRECORD_DIR, 'valid', validation_set, N_SHARDS)
test_filepaths = write_multiple_tfrecords(TFRECORD_DIR, 'test', test_set, N_SHARDS)
