In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from termcolor import colored

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Flatten

from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

In [2]:
def get_minst_df(train_ratio=0.8, seed=None):
    """Load MNIST via tensorflow_datasets and split it into train/validation/test.

    The TFDS "train" split is shuffled once and then partitioned with
    ``take``/``skip``; the TFDS "test" split is returned untouched.

    Args:
        train_ratio: Fraction of the TFDS "train" split assigned to the
            training set; the remainder becomes the validation set.
            Must lie strictly between 0 and 1.
        seed: Optional seed forwarded to the one-off shuffle so the split can
            be reproduced across runs.

    Returns:
        Tuple ``(train_ds, validation_ds, test_ds)`` of unbatched datasets
        yielding ``(image, label)`` pairs (``as_supervised=True``).

    Raises:
        ValueError: If ``train_ratio`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_ratio < 1.0:
        raise ValueError("train_ratio must be strictly between 0 and 1")

    #load data
    # shuffle_files=False keeps the upstream read order stable, so the
    # take/skip partition below selects the same examples on every pass.
    (train_validation_ds,test_ds),ds_info = tfds.load(name = "mnist",
                                                      split = ["train","test"],
                                                      shuffle_files=False,
                                                      as_supervised=True,
                                                      with_info = True)

    #get number of examples in the train/validation pool
    n_train_validation_ds = ds_info.splits["train"].num_examples

    #shuffle before splitting into train and valid set
    # A full-size buffer gives a uniform shuffle. reshuffle_each_iteration=False
    # is essential: with the default (True) the order changes on every epoch,
    # so take()/skip() would pick different examples each time and validation
    # examples would leak into training.
    train_validation_ds = train_validation_ds.shuffle(
        n_train_validation_ds,
        seed=seed,
        reshuffle_each_iteration=False)

    #split into train and validation dataset
    n_train = int(train_ratio*n_train_validation_ds)
    n_valid = n_train_validation_ds - n_train

    train_ds = train_validation_ds.take(n_train) #train: first train_ratio of the shuffled data
    validation_ds = train_validation_ds.skip(n_train).take(n_valid)

    print("train_ds num:", len(train_ds))
    print("validation_ds num:", len(validation_ds))
    print("test_ds num:", len(test_ds))

    return (train_ds,validation_ds,test_ds)

In [3]:
def standardization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE):
    """Rescale and batch the global train/validation/test datasets in place.

    Rebinds the module-level ``train_ds``, ``validation_ds`` and ``test_ds``;
    nothing is returned. Calling it again re-applies the mapping and batching
    to the already transformed datasets.

    Args:
        TRAIN_BATCH_SIZE: Batch size applied to ``train_ds``.
        TEST_BATCH_SIZE: Batch size applied to both ``validation_ds`` and
            ``test_ds``.
    """
    global train_ds,validation_ds,test_ds

    def _scale_pixels(images, labels):
        # Cast to float32 and divide by 255; labels pass through unchanged.
        scaled = tf.cast(images, tf.float32)/255.
        return [scaled, labels]

    scaled_train = train_ds.map(_scale_pixels)
    scaled_validation = validation_ds.map(_scale_pixels)
    scaled_test = test_ds.map(_scale_pixels)

    # Only the training stream is shuffled (buffer of 1000) before batching.
    train_ds = scaled_train.shuffle(1000).batch(TRAIN_BATCH_SIZE)
    validation_ds = scaled_validation.batch(TEST_BATCH_SIZE)
    test_ds = scaled_test.batch(TEST_BATCH_SIZE)

In [4]:
#define model
class MINIST_Classifier(Model): #inherit tensorflow model
    """Fully connected classifier: Flatten -> Dense(64, relu) -> Dense(10, softmax)."""

    def __init__(self):
        super().__init__()

        # Layers are created once here and reused on every forward pass.
        self.flatten = Flatten()
        self.layer_1 = Dense(64, activation="relu")
        self.layer_2 = Dense(10, activation="softmax")

    def call(self, x):
        """Forward pass: flatten ``x``, apply the hidden layer, return the softmax output."""
        hidden = self.layer_1(self.flatten(x))
        return self.layer_2(hidden)
        
def load_metrics():
    """Create fresh loss/accuracy metric objects and bind them to module globals.

    Sets ``train_loss``, ``validation_loss`` and ``test_loss`` to new ``Mean``
    metrics and ``train_acc``, ``validation_acc`` and ``test_acc`` to new
    ``SparseCategoricalAccuracy`` metrics. Returns nothing.
    """
    global train_loss, train_acc
    global validation_loss, validation_acc
    global test_loss, test_acc

    # Running averages of the per-batch loss, one per data split.
    train_loss = Mean()
    validation_loss = Mean()
    test_loss = Mean()

    # Accuracy trackers, one per data split.
    train_acc = SparseCategoricalAccuracy()
    validation_acc = SparseCategoricalAccuracy()
    test_acc = SparseCategoricalAccuracy()
    
@tf.function
def trainer():
    """Run one full pass over the global ``train_ds`` and update ``model``.

    For each batch: forward pass under a ``GradientTape``, loss from
    ``loss_object``, one ``optimizer`` step on ``model.trainable_variables``,
    then the batch loss and predictions are fed to ``train_loss`` and
    ``train_acc``. Takes no arguments and returns nothing; all state is read
    from module-level globals.

    NOTE(review): the globals are looked up when tf.function traces this body;
    confirm whether re-binding ``model`` / ``train_ds`` in a later cell needs
    this definition to be re-run first.
    """
    global train_ds, model, loss_object, optimizer
    global train_acc, train_loss

    for images, labels in train_ds:
        with tf.GradientTape() as tape:
            # Make the mode explicit so any train-only layer added later
            # (e.g. Dropout) behaves correctly; a no-op for the current layers.
            predictions = model(images, training=True)
            loss = loss_object(labels, predictions)

        # trainable_variables is read after the forward pass, once the model
        # has been called on a batch.
        gradients = tape.gradient(loss,model.trainable_variables)
        optimizer.apply_gradients(zip(gradients,model.trainable_variables))

        train_loss(loss)
        train_acc(labels, predictions)
        
@tf.function
def validation():
    """Evaluate ``model`` on the global ``validation_ds`` without updating weights.

    Each batch's loss and predictions are accumulated into ``validation_loss``
    and ``validation_acc``. Takes no arguments and returns nothing; all state
    is read from module-level globals.
    """
    global validation_ds, model, loss_object
    global validation_acc, validation_loss

    for images, labels in validation_ds:
        # Explicit inference mode: keeps evaluation correct if train-only
        # layers are added later; a no-op for the current layers.
        predictions = model(images, training=False)
        loss = loss_object(labels, predictions)

        validation_loss(loss)
        validation_acc(labels, predictions)
        
        
@tf.function
def test():
    """Evaluate ``model`` on the global ``test_ds`` without updating weights.

    Each batch's loss and predictions are accumulated into ``test_loss`` and
    ``test_acc``. Takes no arguments and returns nothing; all state is read
    from module-level globals.
    """
    global test_ds, model, loss_object
    global test_acc, test_loss

    for images, labels in test_ds:
        # Explicit inference mode: keeps evaluation correct if train-only
        # layers are added later; a no-op for the current layers.
        predictions = model(images, training=False)
        loss = loss_object(labels, predictions)

        test_loss(loss)
        test_acc(labels, predictions)
        
        
def reporter():
    """Print the current epoch's train/validation metrics, then reset them.

    Reads the global loop index ``epoch`` (printed 1-based) and the global
    train/validation metric objects. After printing, the four metrics are
    reset so the next epoch starts accumulating from zero. The test metrics
    are not touched. Returns nothing.
    """
    global epoch
    global train_acc, train_loss
    global validation_acc, validation_loss

    print(colored("EPOCH {}".format(epoch+1), "white","on_cyan"))
    template = "Train Loss:{:.4f}\t Train Acc:{:.2f}%\nValid Loss:{:.4f}\t Valid Acc:{:.2f}%"
    print(template.format(train_loss.result(),
                          train_acc.result()*100,
                          validation_loss.result(),
                          validation_acc.result()*100))

    for metric in (train_acc, train_loss, validation_loss, validation_acc):
        # reset_state() is the current Keras name; reset_states() is the
        # deprecated spelling and is missing in newer Keras releases. Prefer
        # the new name and fall back so older installs keep working.
        reset = getattr(metric, "reset_state", None)
        if reset is None:
            reset = metric.reset_states
        reset()

In [5]:
# Training configuration.
EPOCH = 30
LR = 0.001
# Both batch sizes are 32: that is what the recorded run used, because the
# standardization() call used to pass literal 32s and ignore these constants.
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 32

# Build the datasets, then rescale and batch them (rebinds the globals).
train_ds,validation_ds,test_ds = get_minst_df()
standardization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE)

load_metrics()

model = MINIST_Classifier()
loss_object = SparseCategoricalCrossentropy()
optimizer = SGD(learning_rate=LR)

# One pass over train and validation per epoch; reporter() prints and resets
# the per-epoch metrics.
for epoch in range(EPOCH):
    trainer()
    validation()
    reporter()

# Final evaluation on the held-out test split; results stay in test_loss/test_acc.
test()



train_ds num: 48000
validation_ds num: 12000
test_ds num: 10000
[46m[37mEPOCH 1[0m
Train Loss:1.8795	 Train Acc:45.05%
Valid Loss:1.4440	 Valid Acc:68.63%
[46m[37mEPOCH 2[0m
Train Loss:1.1655	 Train Acc:75.26%
Valid Loss:0.9581	 Valid Acc:79.76%
[46m[37mEPOCH 3[0m
Train Loss:0.8346	 Train Acc:81.49%
Valid Loss:0.7415	 Valid Acc:83.18%
[46m[37mEPOCH 4[0m
Train Loss:0.6771	 Train Acc:83.94%
Valid Loss:0.6277	 Valid Acc:85.03%
[46m[37mEPOCH 5[0m
Train Loss:0.5881	 Train Acc:85.47%
Valid Loss:0.5610	 Valid Acc:86.05%
[46m[37mEPOCH 6[0m
Train Loss:0.5299	 Train Acc:86.61%
Valid Loss:0.5136	 Valid Acc:86.93%
[46m[37mEPOCH 7[0m
Train Loss:0.4897	 Train Acc:87.27%
Valid Loss:0.4775	 Valid Acc:87.70%
[46m[37mEPOCH 8[0m
Train Loss:0.4612	 Train Acc:87.78%
Valid Loss:0.4531	 Valid Acc:88.05%
[46m[37mEPOCH 9[0m
Train Loss:0.4382	 Train Acc:88.29%
Valid Loss:0.4342	 Valid Acc:88.47%
[46m[37mEPOCH 10[0m
Train Loss:0.4201	 Train Acc:88.72%
Valid Loss:0.4162	 Valid Acc:88

In [7]:
# Report the test metrics accumulated by test().
template = "Test Loss:{:.4f}\t Test Acc:{:.2f}%"
final_test_loss = test_loss.result()
final_test_acc_pct = test_acc.result()*100
print(template.format(final_test_loss, final_test_acc_pct))

Test Loss:0.2867	 Test Acc:91.99%
