In [1]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import optimizers
from termcolor import colored

import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Flatten, Dense, Activation

from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD, Adam

from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

In [2]:
def get_mnist_ds(train_ratio=0.8):
  """Load MNIST from TFDS and split it into train / validation / test datasets.

  The TFDS 'train' split is divided into a training part and a held-out
  validation part with TFDS split slicing. Slicing selects examples by index,
  so the two subsets are fixed and cannot overlap. This holds even though
  `shuffle_files=True` lets the read order change between iterations;
  `take`/`skip` on a file-shuffled stream gives no such guarantee.

  Args:
    train_ratio: fraction of the TFDS 'train' split used for training; the
      remainder becomes the validation set. It is rounded to a whole percent
      because percent slicing works at that granularity (0.8 -> 80%).

  Returns:
    (train_ds, validation_ds, test_ds): unbatched `tf.data.Dataset`s of
    (image, label) pairs (loaded with `as_supervised=True`).

  Raises:
    ValueError: if `train_ratio` leaves less than 1% for either subset.
  """
  train_pct = int(round(train_ratio * 100))
  if not 0 < train_pct < 100:
    raise ValueError(
        'train_ratio must leave at least 1% for both train and validation, '
        'got {!r}'.format(train_ratio))

  train_ds, validation_ds, test_ds = tfds.load(
      name='mnist',
      split=['train[:{}%]'.format(train_pct),   # training portion of 'train'
             'train[{}%:]'.format(train_pct),   # held-out validation portion
             'test'],
      shuffle_files=True,
      as_supervised=True)

  return train_ds, validation_ds, test_ds

In [3]:
def normalization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE, shuffle_buffer_size=1000):
  """Rescale pixels and batch the global train/validation/test datasets.

  Rebinds the module-level `train_ds`, `validation_ds` and `test_ds` in place
  and also returns them. The function is NOT idempotent: calling it twice maps
  and batches the already-batched datasets again. Reload them with
  `get_mnist_ds()` before re-running.

  Args:
    TRAIN_BATCH_SIZE: batch size for the training dataset.
    TEST_BATCH_SIZE: batch size for the validation and test datasets.
    shuffle_buffer_size: buffer size for shuffling the training examples.

  Returns:
    (train_ds, validation_ds, test_ds) after preprocessing.
  """
  global train_ds, validation_ds, test_ds

  def stnd(images, labels):
    # Despite the name this rescales to [0, 1] (divide by 255); it does not
    # standardize. Presumes raw pixel values in [0, 255] — confirm for other data.
    images = tf.cast(images, tf.float32) / 255.
    return (images, labels)

  # Shuffle individual examples before batching so batch composition varies.
  train_ds = train_ds.map(stnd).shuffle(shuffle_buffer_size).batch(TRAIN_BATCH_SIZE)
  validation_ds = validation_ds.map(stnd).batch(TEST_BATCH_SIZE)
  test_ds = test_ds.map(stnd).batch(TEST_BATCH_SIZE)

  return train_ds, validation_ds, test_ds

In [4]:
class MNIST_Classifier(Model):
  """Small fully connected classifier.

  Pipeline: flatten each input -> Dense(64, ReLU) -> Dense(10, softmax).
  The output is a probability distribution over 10 classes, so it pairs with
  a loss that expects probabilities rather than logits.
  """

  def __init__(self):
    super().__init__()

    # Sub-layers are assigned as attributes so Keras tracks their weights.
    self.flatten = Flatten()
    self.d1 = Dense(units=64, activation='relu')
    self.d2 = Dense(units=10, activation='softmax')

  def call(self, x):
    """Forward pass: flatten, hidden ReLU layer, softmax output."""
    return self.d2(self.d1(self.flatten(x)))

In [5]:
# NOTE(review): the training cell later rebinds `model = MNIST_Classifier()`
# before `trainer()` is ever called, so this instance is never trained; this
# cell only confirms the class can be constructed.
model = MNIST_Classifier()

In [6]:
def load_metrics():
  """Create fresh loss and accuracy metric objects for every phase.

  Binds six module-level names used by the training/evaluation functions:
  `train_loss`, `validation_loss`, `test_loss` (running means of the loss) and
  `train_acc`, `validation_acc`, `test_acc` (sparse categorical accuracy).
  Calling it again replaces the objects, discarding any accumulated state.
  """
  global train_loss, train_acc
  global validation_loss, validation_acc
  global test_loss, test_acc

  # Loss trackers are built first, then accuracy trackers.
  train_loss, validation_loss, test_loss = Mean(), Mean(), Mean()
  train_acc, validation_acc, test_acc = (SparseCategoricalAccuracy(),
                                         SparseCategoricalAccuracy(),
                                         SparseCategoricalAccuracy())

In [7]:
@tf.function
def trainer():
  """Run one pass over the global `train_ds`, updating weights batch by batch.

  For each batch the function does four things:
  - runs a forward pass under a GradientTape;
  - computes the loss with the global `loss_object`;
  - applies the gradients with the global `optimizer`;
  - updates the running `train_loss` / `train_acc` metrics.

  Metrics are not reset here.
  """
  global train_ds, model, loss_object, optimizer
  global train_loss, train_acc

  for batch_images, batch_labels in train_ds:
    with tf.GradientTape() as grad_tape:
      batch_preds = model(batch_images)
      batch_loss = loss_object(batch_labels, batch_preds)

    # Fetch the variable list only after the forward pass: a subclassed model
    # creates its weights on its first call, so reading it earlier may be empty.
    params = model.trainable_variables
    optimizer.apply_gradients(zip(grad_tape.gradient(batch_loss, params), params))

    train_loss(batch_loss)
    train_acc(batch_labels, batch_preds)

@tf.function
def validation():
  """Evaluate the global `model` on every batch of the global `validation_ds`.

  Only accumulates into the running `validation_loss` / `validation_acc`
  metrics; no weights are changed and no metric is reset here.
  """
  global validation_ds, model, loss_object
  global validation_loss, validation_acc

  for val_images, val_labels in validation_ds:
    val_preds = model(val_images)
    validation_loss(loss_object(val_labels, val_preds))
    validation_acc(val_labels, val_preds)

# @tf.function intentionally omitted: do not use @tf.function when calling
# print() or while debugging. This function prints results formatted with
# Python string formatting.
def tester():
  """Evaluate the global `model` on the global `test_ds` and print the results.

  Accumulates into the global `test_loss` / `test_acc` metrics, which are not
  reset here, so a second call would report figures over both runs combined.
  """
  global test_ds, model, loss_object
  global test_loss, test_acc

  for test_images, test_labels in test_ds:
    test_preds = model(test_images)
    test_loss(loss_object(test_labels, test_preds))
    test_acc(test_labels, test_preds)

  loss_value = test_loss.result()
  acc_percent = test_acc.result() * 100

  print('\n=====TEST RESULT=====\n')
  print('Test Loss: {:.4f}\t Test Accuracy: {:.2f}%'.format(loss_value, acc_percent))


def train_repoter():
  """Print the train/validation metrics for the epoch just finished, then reset them.

  Reads the global loop counter `epoch` (0-based; printed as `epoch + 1`) and
  the train/validation metric objects. Resetting them afterwards makes each
  epoch's report cover only that epoch's batches.
  """
  global epoch
  global train_loss, train_acc
  global validation_loss, validation_acc

  print(colored('Epoch', 'red', 'on_white'), epoch + 1)

  template = 'Train Loss: {:.4f}\t Train Accuracy: {:.2f}%\nValidation Loss: {:.4f}\t Validation Accuracy: {:.2f}%\n'
  print(template.format(train_loss.result(), train_acc.result() * 100,
                        validation_loss.result(), validation_acc.result() * 100))

  for metric in (train_loss, train_acc, validation_loss, validation_acc):
    # `reset_state` is the current Keras name; `reset_states` is the legacy
    # alias that newer Keras releases drop. Prefer the former when present.
    reset = getattr(metric, 'reset_state', None) or metric.reset_states
    reset()

# --- Configuration ------------------------------------------------------------
EPOCHS = 15
LR = 0.005

TRAIN_BATCH_SIZE = 16
TEST_BATCH_SIZE = 32

# Global TF seed so weight initialization and tf.data shuffle ops repeat across
# "Restart & Run All". TFDS file shuffling / GPU kernels may still add some
# nondeterminism — verify if exact reproducibility is required.
SEED = 42
tf.random.set_seed(SEED)

# --- Data ---------------------------------------------------------------------
train_ds, validation_ds, test_ds = get_mnist_ds()
normalization(TRAIN_BATCH_SIZE, TEST_BATCH_SIZE)

# --- Model, loss, optimizer, metrics ------------------------------------------
model = MNIST_Classifier()

# The model's last layer already applies softmax, so the loss is used with its
# default (probabilities, not logits).
loss_object = SparseCategoricalCrossentropy()
optimizer = SGD(learning_rate=LR)

load_metrics()

# --- Train and validate each epoch, then one final test pass -------------------
for epoch in range(EPOCHS):
  trainer()
  validation()
  train_repoter()

tester()

[47m[31mEpoch[0m 1
Train Loss: 0.7206	 Train Accuracy: 81.85%
Validation Loss: 0.4177	 Validation Accuracy: 89.05%

[47m[31mEpoch[0m 2
Train Loss: 0.3656	 Train Accuracy: 89.82%
Validation Loss: 0.3445	 Validation Accuracy: 90.42%

[47m[31mEpoch[0m 3
Train Loss: 0.3131	 Train Accuracy: 91.18%
Validation Loss: 0.3086	 Validation Accuracy: 91.53%

[47m[31mEpoch[0m 4
Train Loss: 0.2829	 Train Accuracy: 91.92%
Validation Loss: 0.2870	 Validation Accuracy: 92.23%

[47m[31mEpoch[0m 5
Train Loss: 0.2612	 Train Accuracy: 92.56%
Validation Loss: 0.2658	 Validation Accuracy: 92.73%

[47m[31mEpoch[0m 6
Train Loss: 0.2435	 Train Accuracy: 93.17%
Validation Loss: 0.2517	 Validation Accuracy: 93.12%

[47m[31mEpoch[0m 7
Train Loss: 0.2285	 Train Accuracy: 93.60%
Validation Loss: 0.2401	 Validation Accuracy: 93.40%

[47m[31mEpoch[0m 8
Train Loss: 0.2153	 Train Accuracy: 93.96%
Validation Loss: 0.2316	 Validation Accuracy: 93.62%

[47m[31mEpoch[0m 9
Train Loss: 0.2040	 Train 