In [5]:
!pip install tensorflow-datasets

[0m

In [3]:
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
import tensorflow_datasets as tfds


# To Avoid GPU errors
physical_devices = tf.config.list_physical_devices("GPU")
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [4]:
(ds_train, ds_test), ds_info = tfds.load(
    "mnist",
    split=['train', 'test'],
    shuffle_files=True, 
    as_supervised=True,
    with_info=True
)

In [5]:
def normalize_img(image, label):
    #normalize images
    return tf.cast(image, tf.float32)/255.0, label

In [6]:
AUTOTUNE = tf.data.experimental.AUTOTUNE
ds_train = ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)

In [7]:
BATCH_SIZE = 64
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
ds_train = ds_train.batch(BATCH_SIZE)
ds_train = ds_train.prefetch(AUTOTUNE)

In [8]:
ds_test = ds_test.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_test = ds_test.batch(128)
ds_train = ds_train.prefetch(AUTOTUNE)

In [9]:
model = keras.Sequential([
    keras.Input((28, 28, 1)),
    layers.Conv2D(32, 3, activation='relu'),
    layers.Flatten(),
    layers.Dense(10, activation='softmax')
])

In [10]:
epochs = 5

In [12]:
optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [13]:
acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [17]:
for epoch in range(epochs):
    print('> Starting of Epoch #', epoch)
    for batch_idx, (x_batch, y_batch) in enumerate(ds_train):
        with tf.GradientTape() as tape:
            y_pred = model(x_batch, training=True)
            loss = loss_fn(y_batch, y_pred)
        
        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))
        acc_metric.update_state(y_batch, y_pred)
    train_acc = acc_metric.result()
    print(f'\t> Accuracy over epoch: {train_acc}')
    acc_metric.reset_states()

> Starting of Epoch # 0
	> Accuracy over epoch: 0.9412999749183655
> Starting of Epoch # 1
	> Accuracy over epoch: 0.9786499738693237
> Starting of Epoch # 2
	> Accuracy over epoch: 0.9842000007629395
> Starting of Epoch # 3
	> Accuracy over epoch: 0.9880499839782715
> Starting of Epoch # 4
	> Accuracy over epoch: 0.9902333617210388


In [21]:
acc_metric.reset_states()
for batch_idx, (x_batch, y_batch) in enumerate(ds_test):
    y_pred = model(x_batch, training=False)
    acc_metric.update_state(y_batch, y_pred)
print(f'Test Accuracy: {acc_metric.result()}')
acc_metric.reset_states()


Test Accuracy: 0.9837999939918518
