### This notebook follows the code from the TensorFlow "advanced quickstart" tutorial:
https://www.tensorflow.org/tutorials/quickstart/advanced

In [1]:
import torch 
import tensorflow as tf
print(tf.__version__)

import matplotlib.pyplot as plt
%matplotlib inline

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Flatten, Dense

2.3.0


In [2]:
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [3]:
# gpus = tf.config.experimental.list_physical_devices('GPU')
# if gpus:
#     # Restrict TensorFlow to only allocate 1GB * 2 of memory on the first GPU
#     try:
#         tf.config.experimental.set_virtual_device_configuration(
#             gpus[0],
#             [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024 * 2)])
#         logical_gpus = tf.config.experimental.list_logical_devices('GPU')
#         print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
#     except RuntimeError as e:
#         # Virtual devices must be set before GPUs have been initialized
#         print(e)

In [4]:
# Fetch MNIST through the Keras dataset helper; it returns an
# (inputs, labels) pair for the training split and one for the test split.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale both image arrays by 1/255
# (presumably mapping 8-bit pixel intensities into [0, 1] -- TODO confirm).
x_train = x_train / 255.0
x_test = x_test / 255.0

# Show the array shapes as the cell's output.
x_train.shape, x_test.shape

((60000, 28, 28), (10000, 28, 28))

In [5]:
# Add a trailing channel dimension, (N, 28, 28) -> (N, 28, 28, 1), and cast
# the images to float32.
# The ndim guard keeps this cell idempotent: without it, re-running the cell
# (without reloading the data first) would append another axis each time.
if x_train.ndim == 3:
    x_train = x_train[..., tf.newaxis]
if x_test.ndim == 3:
    x_test = x_test[..., tf.newaxis]

x_train = x_train.astype("float32")
x_test = x_test.astype("float32")

# Show the resulting shapes as the cell's output.
x_train.shape, x_test.shape

((60000, 28, 28, 1), (10000, 28, 28, 1))

In [6]:
# Wrap the arrays in tf.data pipelines. Training examples pass through a
# 10000-element shuffle buffer before batching; the test pipeline is not
# shuffled. Both pipelines emit batches of 32.
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((x_train, y_train))
    .shuffle(10000)
    .batch(32)
)
test_ds = (
    tf.data.Dataset
    .from_tensor_slices((x_test, y_test))
    .batch(32)
)

In [7]:
# Define the Keras model with the model-subclassing API.

class MyModel(Model):
    """Small convolutional classifier: Conv2D -> Flatten -> Dense(128) -> Dense(10).

    The final Dense layer has no activation, so the model returns raw scores
    (logits) rather than probabilities.
    """

    def __init__(self):
        """Create the four layers used by the forward pass."""
        super().__init__()
        # 32 filters with a 3x3 kernel, followed by ReLU.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        # Output layer: 10 units, no activation.
        self.d2 = Dense(10)

    def call(self, x):
        """Run the forward pass on a batch `x` and return the output of the last layer."""
        feature_maps = self.conv1(x)
        flat = self.flatten(feature_maps)
        hidden = self.d1(flat)
        return self.d2(hidden)


# Single model instance shared by the training and evaluation steps.
model = MyModel()

In [8]:
# Loss for integer class labels. from_logits=True because the model's last
# layer applies no softmax and therefore outputs raw scores.
# NOTE: the identifier `loss_obect` (sic) is kept as-is because train_step and
# test_step reference it by this exact name.
loss_obect = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Running accuracy metrics, one per data split. The metric names are only
# labels; the test metric's name previously had a typo ('test_acuracy').
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')


In [9]:
# Running means of the per-batch loss values, one per data split; they are
# read back when the results are printed at the end of each epoch.
train_loss_mean, test_loss_mean = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in ('train_loss', 'test_loss')
)

In [10]:
# tf.GradientTape to train the model

@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and update the training metrics.

    Args:
        images: batch of inputs fed to the model.
        labels: targets for the batch, passed to the loss and to the accuracy metric.
    """
    with tf.GradientTape() as tape:
        # The forward pass and the loss are computed inside the tape context
        # so that gradients can be taken afterwards.
        logits = model(images, training=True)
        batch_loss = loss_obect(labels, logits)

    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    # Accumulate the running statistics used for the per-epoch report.
    train_loss_mean(batch_loss)
    train_accuracy(labels, logits)

In [11]:
@tf.function
def test_step(images, labels):
    """Evaluate the model on one batch and update the test metrics.

    No gradients are computed and no weights are updated here.

    Args:
        images: batch of inputs fed to the model.
        labels: targets for the batch, passed to the loss and to the accuracy metric.
    """
    logits = model(images, training=False)
    batch_loss = loss_obect(labels, logits)

    # Accumulate the running statistics used for the per-epoch report.
    test_loss_mean(batch_loss)
    test_accuracy(labels, logits)

In [12]:
EPOCHS = 15

# Report line printed once per epoch; placeholders are filled positionally.
template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'

for epoch in range(EPOCHS):
    # Reset the metrics at the start of every epoch so that each report
    # covers that epoch only.
    for metric in (train_loss_mean, train_accuracy, test_loss_mean, test_accuracy):
        metric.reset_states()

    # One full pass over the training batches.
    for images, labels in train_ds:
        train_step(images, labels)

    # One full pass over the test batches.
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    print(
        template.format(
            epoch + 1,                      # report epochs as 1-based
            train_loss_mean.result(),
            train_accuracy.result() * 100,  # fraction -> percent
            test_loss_mean.result(),
            test_accuracy.result() * 100,   # fraction -> percent
        )
    )

Epoch 1, Loss: 0.13962650299072266, Accuracy: 95.79666900634766, Test Loss: 0.06152357906103134, Test Accuracy: 97.91999816894531
Epoch 2, Loss: 0.0445728525519371, Accuracy: 98.625, Test Loss: 0.05542219057679176, Test Accuracy: 98.29999542236328
Epoch 3, Loss: 0.023003609851002693, Accuracy: 99.288330078125, Test Loss: 0.04975861310958862, Test Accuracy: 98.32999420166016
Epoch 4, Loss: 0.015097888186573982, Accuracy: 99.50833129882812, Test Loss: 0.05564732849597931, Test Accuracy: 98.43999481201172
Epoch 5, Loss: 0.009281602688133717, Accuracy: 99.69666290283203, Test Loss: 0.060736045241355896, Test Accuracy: 98.40999603271484
Epoch 6, Loss: 0.00916693452745676, Accuracy: 99.69000244140625, Test Loss: 0.0707729235291481, Test Accuracy: 98.25
Epoch 7, Loss: 0.004770190455019474, Accuracy: 99.83666229248047, Test Loss: 0.08128125965595245, Test Accuracy: 98.19999694824219
Epoch 8, Loss: 0.006122227758169174, Accuracy: 99.77999877929688, Test Loss: 0.06571256369352341, Test Accuracy: