[구글 코랩(Colab)에서 실행하기](https://colab.research.google.com/github/lovedlim/tensorflow/blob/main/Part%203/3.12_gradient_tape_model.ipynb)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Load the MNIST dataset bundled with Keras.
mnist = tf.keras.datasets.mnist

# load_data() returns ((x_train, y_train), (x_test, y_test)).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixel values into [0, 1].
# Both splits are scaled by the *training* maximum so that train and test go
# through exactly the same transform (the test set must not supply its own
# statistic). The result is cast to float32 instead of keeping the float64
# that NumPy true division produces, halving the memory held by the arrays.
pixel_max = x_train.max()
x_train = (x_train / pixel_max).astype(np.float32)
x_test = (x_test / pixel_max).astype(np.float32)

## 8-3-3. GradientTape

In [None]:
# Model definition: a fully connected classifier over 28x28 inputs.
model = tf.keras.Sequential([
    # Declare the input with an explicit Input instead of passing
    # ``input_shape=`` to Flatten; that argument is deprecated and emits a
    # warning under Keras 3. The resulting model is the same.
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    # 10-way softmax: the model outputs probabilities, which matches the
    # loss below (SparseCategoricalCrossentropy defaults to from_logits=False).
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Loss function: sparse (integer-label) categorical cross-entropy.
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()

# Optimizer: Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

In [None]:
# Running metrics that train_step / valid_step accumulate into and the
# training loop prints.
# Mean keeps a running average of the per-batch loss values it is fed.
train_loss = tf.keras.metrics.Mean(name='train_loss')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')
# SparseCategoricalAccuracy compares integer labels with the argmax of the
# predicted class scores.
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')

In [None]:
# Mini-batch generator.
def get_batches(x, y, batch_size=32, drop_remainder=False):
    """Yield consecutive ``(x_batch, y_batch)`` mini-batches in order.

    Args:
        x: Array of inputs, sliced along its first axis.
        y: Array of targets, aligned with ``x`` along the first axis.
        batch_size: Maximum number of samples per batch; must be positive.
        drop_remainder: If True, skip a trailing batch that has fewer than
            ``batch_size`` samples. Defaults to False so every sample is
            visited (previously the remainder was always dropped, e.g. the
            last 16 of 10,000 test images were never evaluated).

    Yields:
        Tuples ``(x_batch, y_batch)`` of NumPy arrays.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    Note:
        A shorter final batch has a different leading dimension, so a
        ``tf.function`` consuming these batches may retrace once for it.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    n_samples = x.shape[0]
    # With drop_remainder, stop at the last multiple of batch_size.
    stop = n_samples - n_samples % batch_size if drop_remainder else n_samples
    for start in range(0, stop, batch_size):
        end = start + batch_size
        yield (np.asarray(x[start:end]), np.asarray(y[start:end]))

In [None]:
@tf.function
def train_step(images, labels):
    """Run one optimization step on a mini-batch and record training metrics.

    Uses the module-level ``model``, ``loss_function`` and ``optimizer``, and
    feeds the batch loss and predictions into ``train_loss`` and
    ``train_accuracy``.
    """
    # Record the forward pass so the loss can be differentiated w.r.t. weights.
    with tf.GradientTape() as tape:
        probs = model(images, training=True)
        batch_loss = loss_function(labels, probs)

    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))

    # Accumulate into the running training metrics.
    train_loss(batch_loss)
    train_accuracy(labels, probs)

@tf.function
def valid_step(images, labels):
    """Evaluate the model on a mini-batch and record validation metrics.

    No gradients are computed and no weights are updated; only
    ``valid_loss`` and ``valid_accuracy`` change.
    """
    probs = model(images, training=False)
    valid_loss(loss_function(labels, probs))
    valid_accuracy(labels, probs)

In [None]:
# One summary line per epoch; the format string is loop-invariant.
metric_template = 'epoch: {}, loss: {:.4f}, acc: {:.2f}%, val_loss: {:.4f}, val_acc: {:.2f}%'

# Epoch loop
for epoch in range(5):
    # Reset every metric at the start of EACH epoch so the printed numbers
    # describe this epoch only. Resetting once before the loop made every
    # line a running average over all epochs so far.
    # reset_state() is the current Keras name; reset_states() is deprecated
    # and no longer exists in Keras 3.
    for metric in (train_loss, train_accuracy, valid_loss, valid_accuracy):
        metric.reset_state()

    # Training pass: one optimizer update per mini-batch.
    for images, labels in get_batches(x_train, y_train):
        train_step(images, labels)

    # Evaluation pass over the test split (valid_step applies no gradients).
    for images, labels in get_batches(x_test, y_test):
        valid_step(images, labels)

    # Report this epoch's results.
    print(metric_template.format(epoch + 1, train_loss.result(), train_accuracy.result() * 100,
                                 valid_loss.result(), valid_accuracy.result() * 100))

epoch: 1, loss: 0.2409, acc: 92.81%, val_loss: 0.1575, val_acc: 95.00%
epoch: 2, loss: 0.1717, acc: 94.87%, val_loss: 0.1387, val_acc: 95.58%
epoch: 3, loss: 0.1363, acc: 95.92%, val_loss: 0.1339, val_acc: 95.88%
epoch: 4, loss: 0.1146, acc: 96.56%, val_loss: 0.1265, val_acc: 96.24%
epoch: 5, loss: 0.0991, acc: 97.01%, val_loss: 0.1256, val_acc: 96.43%
