# 16강
- MNIST Classification

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Flatten, Dense, Activation
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.optimizers import SGD, Adam

from tensorflow.keras.metrics import Mean, SparseCategoricalAccuracy

import numpy as np
import matplotlib.pyplot as plt
from termcolor import colored

def get_mnist_ds(train_percent=80):
  """Load MNIST and return disjoint train / validation / test datasets.

  The first `train_percent` % of the TFDS 'train' split becomes the training
  set, the rest of it becomes the validation set, and the TFDS 'test' split
  is returned unchanged. Every dataset yields (image, label) pairs
  (as_supervised=True) and is unbatched.

  Args:
    train_percent: integer percentage (0-100) of the 'train' split used for
      training; the default gives an 80/20 train/validation split.

  Returns:
    A tuple (train_ds, validation_ds, test_ds) of tf.data.Dataset objects.

  Raises:
    ValueError: if train_percent is outside [0, 100].
  """
  if not 0 <= train_percent <= 100:
    raise ValueError('train_percent must be in [0, 100], got {}'.format(train_percent))

  # Carve train/validation with TFDS split slicing instead of take()/skip()
  # on the loaded 'train' split. With shuffle_files=True the shard order may be
  # reshuffled on every pass over the dataset, so take()/skip() on that stream
  # is not guaranteed to pick the same examples each epoch when the split has
  # several shards, and training and validation could overlap. TFDS resolves
  # split slices to fixed example ranges before any file shuffling, so the
  # two subsets stay disjoint.
  train_split = 'train[:{}%]'.format(train_percent)
  validation_split = 'train[{}%:]'.format(train_percent)

  train_ds, validation_ds, test_ds = tfds.load(name = 'mnist',
                                               shuffle_files = True,
                                               as_supervised = True,
                                               split = [train_split,
                                                        validation_split,
                                                        'test'])
  return train_ds, validation_ds, test_ds

def standardization(train_batch_size, test_batch_size, shuffle_buffer=1000):
  """Rescale images and batch the module-level datasets in place.

  Rebinds the module globals `train_ds`, `validation_ds` and `test_ds`. Each
  is mapped through a function that casts images to float32 and divides them
  by 255. The training set is additionally shuffled with a buffer of
  `shuffle_buffer` elements before batching.

  Args:
    train_batch_size: batch size for `train_ds`.
    test_batch_size: batch size for both `validation_ds` and `test_ds`.
    shuffle_buffer: size of the shuffle buffer for `train_ds`.

  Returns:
    The rebound (train_ds, validation_ds, test_ds), so callers do not have
    to rely on the globals.
  """
  global train_ds, validation_ds, test_ds

  def stnd(images, labels):
    # Cast before dividing so the division is done in float32; for uint8
    # pixels (assumed, not checked here) this maps [0, 255] to [0, 1].
    images = tf.cast(images, tf.float32) / 255.
    # Return a tuple: it is the structure tf.data expects for (x, y) pairs.
    return images, labels

  train_ds = train_ds.map(stnd).shuffle(shuffle_buffer).batch(train_batch_size)
  validation_ds = validation_ds.map(stnd).batch(test_batch_size)
  test_ds = test_ds.map(stnd).batch(test_batch_size)
  return train_ds, validation_ds, test_ds

class MNIST_Classifier(Model):
  """Small fully connected classifier.

  Pipeline: Flatten -> Dense(64, relu) -> Dense(10, softmax). The output is a
  softmax distribution over 10 classes.
  """

  def __init__(self):
    super().__init__()

    # Attribute names are kept unchanged because other code may reference
    # them directly (e.g. model.d1).
    self.Flatten = Flatten()
    self.d1 = Dense(64, activation = 'relu')
    self.d2 = Dense(10, activation = 'softmax')

  def call(self, x):
    """Run `x` through the flatten / hidden / output layers in order."""
    out = x
    for layer in (self.Flatten, self.d1, self.d2):
      out = layer(out)
    return out

def load_metrics():
  """Create fresh loss/accuracy metric objects as module-level globals.

  Defines train_/validation_/test_ loss (Mean) and acc
  (SparseCategoricalAccuracy) globals, which the training, validation and
  test loops update.
  """
  global train_loss, train_acc
  global validation_loss, validation_acc
  global test_loss, test_acc

  # Running means of the per-batch loss values.
  train_loss, validation_loss, test_loss = Mean(), Mean(), Mean()

  # Accuracy of predicted class vs. integer class labels.
  train_acc, validation_acc, test_acc = (SparseCategoricalAccuracy(),
                                         SparseCategoricalAccuracy(),
                                         SparseCategoricalAccuracy())

@tf.function
def trainer():
  """Make one pass over `train_ds`, applying an optimizer step per batch.

  Reads the module globals `train_ds`, `model`, `loss_object` and
  `optimizer`, and accumulates the per-batch loss and accuracy into
  `train_loss` / `train_acc`. Compiled with tf.function, so AutoGraph
  converts the Python loop over the dataset.
  """
  global train_ds, model, loss_object, optimizer
  global train_loss, train_acc

  for batch_images, batch_labels in train_ds:
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
      probs = model(batch_images)
      batch_loss = loss_object(batch_labels, probs)

    # Read trainable_variables inside the loop: the model's weights are only
    # created on its first call.
    grads = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    train_loss(batch_loss)
    train_acc(batch_labels, probs)

@tf.function
def validation():
  """Evaluate `model` on `validation_ds` without updating any weights.

  Reads the module globals `validation_ds`, `model` and `loss_object`, and
  accumulates the per-batch loss and accuracy into `validation_loss` /
  `validation_acc`.
  """
  global validation_ds, model, loss_object
  global validation_loss, validation_acc

  for batch_images, batch_labels in validation_ds:
    probs = model(batch_images)
    batch_loss = loss_object(batch_labels, probs)

    validation_loss(batch_loss)
    validation_acc(batch_labels, probs)

# Not decorated with @tf.function: this function calls print(), and inside a
# tf.function a Python print would only run while the function is traced.
def tester():
  """Evaluate `model` on `test_ds` and print the test loss and accuracy.

  Reads the module globals `test_ds`, `model` and `loss_object`, and
  accumulates into `test_loss` / `test_acc` before printing their results.
  Accuracy is printed as a percentage.
  """
  global test_ds, model, loss_object
  global test_loss, test_acc

  for batch_images, batch_labels in test_ds:
    probs = model(batch_images)
    batch_loss = loss_object(batch_labels, probs)

    test_loss(batch_loss)
    test_acc(batch_labels, probs)

  template ='Test Loss: {:.4f}\t Test Accuracy: {:.2f}%\n'
  loss_value = test_loss.result()
  acc_percent = test_acc.result() * 100
  print(template.format(loss_value, acc_percent))


def train_report():
  """Print this epoch's train/validation metrics, then reset them.

  Reads the module-level loop counter `epoch` (0-based, printed 1-based) and
  the train/validation metric globals created by load_metrics(). Resetting
  the metrics afterwards makes each report cover exactly one epoch.
  """
  # These globals are only read here. The declarations document which
  # module-level names the report depends on.
  global epoch
  global train_loss, train_acc
  global validation_loss, validation_acc

  print(colored('Epoch: ', 'red', 'on_white'), epoch +1)
  template ='Train Loss: {:.4f}\t Train Accuracy: {:.2f}%\n \t Validation Loss: {:.4f}\t Validation Accuracy: {:.2f}%\n'
  print(template.format(train_loss.result(), train_acc.result()*100,
                        validation_loss.result(), validation_acc.result()*100))

  # Clear the accumulated state so the next epoch starts from zero.
  for metric in (train_loss, train_acc, validation_loss, validation_acc):
    metric.reset_states()

# Hyper-parameters.
EPOCHS = 10
LR = 0.001
train_batch_size = 16
test_batch_size = 32

# The test split must be bound to the global name `test_ds`: standardization()
# and tester() read it via `global test_ds`. Binding it to any other name
# leaves `test_ds` undefined, and standardization() raises NameError.
train_ds, validation_ds, test_ds = get_mnist_ds()
standardization(train_batch_size, test_batch_size)

model = MNIST_Classifier()
# MNIST_Classifier ends in a softmax layer, so the loss receives
# probabilities (from_logits defaults to False).
loss_object = SparseCategoricalCrossentropy()
optimizer = SGD(learning_rate = LR)

load_metrics()

# `epoch` is a module global; train_report() reads it for its header line.
for epoch in range(EPOCHS):
  trainer()
  validation()
  train_report()

tester()

[47m[31mEpoch: [0m 1
Train Loss: 1.5173	 Train Accuracy: 60.75%
 	 Validation Loss: 0.9592	 Validation Accuracy: 78.58%

[47m[31mEpoch: [0m 2
Train Loss: 0.7531	 Train Accuracy: 82.82%
 	 Validation Loss: 0.6264	 Validation Accuracy: 85.02%

[47m[31mEpoch: [0m 3
Train Loss: 0.5570	 Train Accuracy: 86.31%
 	 Validation Loss: 0.5110	 Validation Accuracy: 86.89%

[47m[31mEpoch: [0m 4
Train Loss: 0.4743	 Train Accuracy: 87.72%
 	 Validation Loss: 0.4522	 Validation Accuracy: 88.12%

[47m[31mEpoch: [0m 5
Train Loss: 0.4276	 Train Accuracy: 88.65%
 	 Validation Loss: 0.4158	 Validation Accuracy: 88.92%

[47m[31mEpoch: [0m 6
Train Loss: 0.3970	 Train Accuracy: 89.21%
 	 Validation Loss: 0.3909	 Validation Accuracy: 89.42%

[47m[31mEpoch: [0m 7
Train Loss: 0.3751	 Train Accuracy: 89.65%
 	 Validation Loss: 0.3723	 Validation Accuracy: 89.78%

[47m[31mEpoch: [0m 8
Train Loss: 0.3582	 Train Accuracy: 90.07%
 	 Validation Loss: 0.3576	 Validation Accuracy: 90.17%

[47m[3