# Keras訓練模型的方法  
在之前，已經看過了使用`fit()`訓練模型的方法，假如資料龐大，無法將全部資料放進GPU RAM中，可以使用`tf.Data`與`Sequence 類`來解決。  
在Keras中也有另外訓練model的方式`自定義`以及`train_on_batch`，`train_on_batch`幾乎與`fit`差不多，也是透過放入一個一個batch來做訓練。  
在TF2.2以上還可用繼承`Model 類`修改當中的`train_step`來自訂訓練。  
此次，是使用`自定義`來作範例。
詳細資料查閱[Writing a training loop from scratch](https://keras.io/guides/writing_a_training_loop_from_scratch/)

In [1]:
# 載入所需lib
import numpy as np
import math
import time
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
print('TensorFlow version:', tf.__version__)

TensorFlow version: 2.2.0


自訂義訓練需要使用`GradientTape`，在`GradientTape`中呼叫model，就可以根據loss與optimizer調整model.trainable的參數。

In [2]:
def get_model():
    inputs = tf.keras.Input(shape=(28, 28, 1))

    # model layer
    conv_1 = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')
    max_pool_1 = tf.keras.layers.MaxPooling2D()
    conv_2 = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')
    max_pool_2 = tf.keras.layers.MaxPooling2D()
    flatten = tf.keras.layers.Flatten()
    drop = tf.keras.layers.Dropout(0.5)
    output = tf.keras.layers.Dense(10, activation='softmax')

    # path
    x = conv_1(inputs)
    x = max_pool_1(x)
    x = conv_2(x)
    x = max_pool_2(x)
    x = flatten(x)
    x = drop(x)
    x = output(x)

    model = tf.keras.Model(inputs=inputs, outputs=x)
    model.summary()
    return model

In [3]:
#download MNIST dataset and preprocessing
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train.astype('float32') / 255
y_train = y_train.astype('float32')
x_test = x_test.astype('float32') / 255
y_test = y_test.astype('float32')

x_train = np.expand_dims(x_train, -1)
y_train = np.expand_dims(y_train, -1)
x_test = np.expand_dims(x_test, -1)
y_test = np.expand_dims(y_test, -1)

在`fit`中，有些參數需要設定，`batch_size`、`epochs`等等，自定義時候也需要進行處理。

In [4]:
model = get_model()
# parameter init
epochs = 10
batch_size = 128
batch_of_epoch = math.ceil(len(x_train) / batch_size)

# optimizer
opt = tf.keras.optimizers.Adam()
# loss
loss_fu = tf.keras.losses.SparseCategoricalCrossentropy()

# accuracy
acc = tf.keras.metrics.SparseCategoricalAccuracy()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1600)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1600)              0     

## 開始訓練循環

In [5]:
for epoch in range(epochs):
    start_time = time.time()
    #印出第幾個epoch
    print('start of epoch {}'.format(epoch + 1))
    
    #每個batch進行訓練
    for idx in range(batch_of_epoch):
        x_batch = x_train[idx * batch_size : (idx + 1) * batch_size]
        y_batch = y_train[idx * batch_size : (idx + 1) * batch_size]
        
        #使用GradientTape
        with tf.GradientTape() as tape:
            # 某些layer，training與testing有所不同，所以訓練時代入training=True參數
            pred = model(x_batch, training=True)
            loss_value = loss_fu(y_batch, pred)
        # 計算所有trainable weights對應的gradients
        grads = tape.gradient(loss_value, model.trainable_weights)
        # 用optimizer調整trainable weights
        opt.apply_gradients(zip(grads, model.trainable_weights))
        
        #更新 accuracy metric
        acc.update_state(y_batch, pred)
    
    # 訓練完一個epoch，印出accuracy
    print('Training accuracy on {} epoch:'.format(epoch + 1,))
    print('loss:{:7.4f}, accuracy:{:7.4f}%'.format(loss_value, acc.result() * 100))
    
    # 清空accuracy state
    acc.reset_states()
    print('Use times:{:4.4f}'.format(time.time() - start_time))

start of epoch 1
Training accuracy on 1 epoch:
loss: 0.2590, accuracy:90.0100%
Use times:6.2576
start of epoch 2
Training accuracy on 2 epoch:
loss: 0.1912, accuracy:96.7550%
Use times:4.8390
start of epoch 3
Training accuracy on 3 epoch:
loss: 0.1875, accuracy:97.4967%
Use times:4.9229
start of epoch 4
Training accuracy on 4 epoch:
loss: 0.1804, accuracy:97.9233%
Use times:5.0508
start of epoch 5
Training accuracy on 5 epoch:
loss: 0.1787, accuracy:98.1167%
Use times:4.6721
start of epoch 6
Training accuracy on 6 epoch:
loss: 0.1766, accuracy:98.3583%
Use times:4.5598
start of epoch 7
Training accuracy on 7 epoch:
loss: 0.1751, accuracy:98.4017%
Use times:4.4798
start of epoch 8
Training accuracy on 8 epoch:
loss: 0.1742, accuracy:98.6217%
Use times:4.6399
start of epoch 9
Training accuracy on 9 epoch:
loss: 0.1723, accuracy:98.6700%
Use times:4.7882
start of epoch 10
Training accuracy on 10 epoch:
loss: 0.1873, accuracy:98.7600%
Use times:4.5140


原本一個epoch只約兩秒鐘，但是自定義卻要4秒，這時候可以靠`tf.function`來提升性能。
`tf.function`是將某些步驟轉換成static graph。  
改寫一下training loop。

In [6]:
# clear session
tf.keras.backend.clear_session()

model = get_model()
# 定義train function並且用tf.function裝飾
@tf.function
def train_data(x, y):
    with tf.GradientTape() as tape:
        pred = model(x_batch, training=True)
        loss_value = loss_fu(y_batch, pred)
    grads = tape.gradient(loss_value, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))
    acc.update_state(y_batch, pred)
    
    return loss_value

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1600)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1600)              0     

In [7]:
for epoch in range(epochs):
    start_time = time.time()
    #印出第幾個epoch
    print('start of epoch {}'.format(epoch + 1))
    
    #每個batch進行訓練
    for idx in range(batch_of_epoch):
        x_batch = x_train[idx * batch_size : (idx + 1) * batch_size]
        y_batch = y_train[idx * batch_size : (idx + 1) * batch_size]
        
        loss_value = train_data(x_batch, y_batch)
    
    # 訓練完一個epoch，印出accuracy
    print('Training accuracy on {} epoch:'.format(epoch + 1,))
    print('loss:{:7.4f}, accuracy:{:7.4f}%'.format(loss_value, acc.result() * 100))
    
    # 清空accuracy state
    acc.reset_states()
    print('Use times:{:4.4f}'.format(time.time() - start_time))

start of epoch 1
Training accuracy on 1 epoch:
loss: 1.4528, accuracy:98.9200%
Use times:1.1737
start of epoch 2
Training accuracy on 2 epoch:
loss: 0.3234, accuracy:99.9633%
Use times:0.7426
start of epoch 3
Training accuracy on 3 epoch:
loss: 0.2433, accuracy:99.9933%
Use times:0.7224
start of epoch 4
Training accuracy on 4 epoch:
loss: 0.1452, accuracy:99.9917%
Use times:0.7287
start of epoch 5
Training accuracy on 5 epoch:
loss: 0.0395, accuracy:99.9983%
Use times:0.7254
start of epoch 6
Training accuracy on 6 epoch:
loss: 0.0873, accuracy:99.9967%
Use times:0.7274
start of epoch 7
Training accuracy on 7 epoch:
loss: 0.1008, accuracy:99.9900%
Use times:0.7199
start of epoch 8
Training accuracy on 8 epoch:
loss: 0.0415, accuracy:99.9967%
Use times:0.7356
start of epoch 9
Training accuracy on 9 epoch:
loss: 0.0051, accuracy:99.9983%
Use times:0.7471
start of epoch 10
Training accuracy on 10 epoch:
loss: 0.0130, accuracy:99.9983%
Use times:0.7226


可以看到使用`tf.function`將一個epoch訓練時間縮小到不到1秒，但是如果是一些很複雜的操作，使用`tf.function`可能就沒有這麼好的效果。

## **總結**  
自定義訓練循環與`fit`大同小異，透過`tf.GradientTape`來算出gradient，再透過optimizer對model的權重進行更新。  
在TF2.2之後，若想要有`fit`的一些優點，以及`自定義訓練`，可以改寫`Model類`的`train_step` method，預計在之後的文章說明。