In [1]:
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

tfds.disable_progress_bar()
tf.enable_v2_behavior()

# Global seed so weight initialization and tf.data shuffling are reproducible
# across Restart & Run All.
# NOTE(review): tfds file shuffling (shuffle_files=True) is seeded separately
# (e.g. via tfds.ReadConfig(shuffle_seed=...)) — confirm if exact reproducibility matters.
SEED = 42
tf.random.set_seed(SEED)

In [None]:
# Fashion-MNIST provides only 'train' and 'test' splits; the test split is used
# as the validation set when fitting the model.
(ds_train, ds_val), ds_info = tfds.load(
    'fashion_mnist',
    split=['train', 'test'],
    shuffle_files=True,   # randomize the order in which the TFRecord files are read
    as_supervised=True,   # yield (image, label) tuples instead of feature dicts
    with_info=True,       # also return the DatasetInfo metadata object
)

[1mDownloading and preparing dataset fashion_mnist/3.0.0 (download: 29.45 MiB, generated: Unknown size, total: 29.45 MiB) to /root/tensorflow_datasets/fashion_mnist/3.0.0...[0m
Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.0.incompleteSCM8UE/fashion_mnist-train.tfrecord
Shuffling and writing examples to /root/tensorflow_datasets/fashion_mnist/3.0.0.incompleteSCM8UE/fashion_mnist-test.tfrecord
[1mDataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/3.0.0. Subsequent calls will reuse this data.[0m


In [None]:
# Display the dataset metadata: feature spec, split sizes and citation.
ds_info

tfds.core.DatasetInfo(
    name='fashion_mnist',
    version=3.0.0,
    description='Fashion-MNIST is a dataset of Zalando's article images consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes.',
    homepage='https://github.com/zalandoresearch/fashion-mnist',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{DBLP:journals/corr/abs-1708-07747,
      author    = {Han Xiao and
                   Kashif Rasul and
                   Roland Vollgraf},
      title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
                   Algorithms},
      journal   = {CoRR},
      volume

In [None]:
# Pull the first (image, label) pair from the raw training split to inspect it.
sample = next(iter(ds_train))
X_sample, y_sample = sample

X_sample.shape

TensorShape([28, 28, 1])

In [None]:
import matplotlib.pyplot as plt
# Drop the trailing single-channel axis, (28, 28, 1) -> (28, 28), so imshow
# receives a plain 2-D image.
sample_img = X_sample[:, :, 0]

In [None]:
# Show the sample together with its human-readable class name. The image is
# single-channel, so render it in grayscale; otherwise imshow applies its
# default (false-colour) colormap.
label_id = int(y_sample.numpy())
label_name = ds_info.features['label'].int2str(label_id)
print('Label:', label_id, f'({label_name})')

fig, ax = plt.subplots()
ax.imshow(sample_img, cmap='gray')
ax.set_title(f'Label {label_id}: {label_name}')
ax.axis('off');

### Data Transformations

In [None]:
BATCH_SIZE = 128
AUTOTUNE = tf.data.experimental.AUTOTUNE


def normalize(img, label):
  """Scale a uint8 image to float32 in [0, 1]; pass the label through unchanged."""
  return tf.cast(img, tf.float32) / 255., label


# Training pipeline:
#   1. Normalize the images.
#   2. Cache the normalized examples in memory.
#   3. Reshuffle the examples every epoch. shuffle_files=True only randomizes
#      the file read order, not the example order.
#   4. Batch, then prefetch so input preparation overlaps with training.
# NOTE: this cell rebinds ds_train/ds_val, so running it twice would normalize
# and batch twice. Re-run the loading cell before re-running this one.
ds_train = (
    ds_train
    .map(normalize, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)

# Validation pipeline: order does not matter, so no shuffling.
ds_val = (
    ds_val
    .map(normalize, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(AUTOTUNE)
)

### Network Architecture

In [None]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense

In [None]:
# Flatten the (28, 28, 1) image into 784 pixel values and feed a single softmax
# Dense layer. There is no hidden layer, so this is a linear (multinomial
# logistic-regression) classifier over the 10 classes.
model = Sequential([
    Flatten(input_shape=(28, 28, 1)),
    Dense(10, activation='softmax'),
])
# Labels are integer class ids, hence the *sparse* categorical cross-entropy.
# `metrics` is documented as a list of metrics.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['acc'])
model.summary()

In [None]:
EPOCHS = 50
# Keep the History object so the per-epoch loss/acc curves can be inspected or
# plotted later. ds_val is the Fashion-MNIST *test* split, so the val_* numbers
# are effectively test-set scores.
history = model.fit(ds_train, epochs=EPOCHS, validation_data=ds_val)