**Applying Transfer Learning to a CNN for Handwritten Digit Recognition**

In [1]:
%env KERAS_BACKEND=tensorflow

env: KERAS_BACKEND=tensorflow


In [2]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

In [3]:
# Keras functions
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import SGD

# Keras dataset
from keras.datasets import mnist

# Keras utility functions
from keras.utils import np_utils

In [4]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

**1. Preparing the input and output data**

In [5]:
x_train = x_train.reshape(60000, 28, 28, 1)
x_test = x_test.reshape(10000 , 28, 28, 1)
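
The reshape only adds a trailing channel axis, because `Conv2D` (with Keras's default `channels_last` layout) expects input of shape (height, width, channels). The NumPy sketch below uses a few random stand-in images, not the real MNIST arrays, to show that `reshape(-1, 28, 28, 1)`, `x[..., np.newaxis]` and `np.expand_dims(x, axis=-1)` produce the same array; `-1` avoids hard-coding the 60000/10000 sample counts.

```python
import numpy as np

# Stand-in for x_train: 5 random grayscale 28x28 images (not real MNIST data)
rng = np.random.default_rng(0)
x = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Same operation as x_train.reshape(60000, 28, 28, 1), without hard-coding N
x_a = x.reshape(-1, 28, 28, 1)

# Two equivalent ways to append a channel axis of length 1
x_b = x[..., np.newaxis]
x_c = np.expand_dims(x, axis=-1)

print(x_a.shape)  # (5, 28, 28, 1)
print(np.array_equal(x_a, x_b) and np.array_equal(x_b, x_c))  # True
```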

In [6]:
x_train_01 = x_train[y_train <= 1]  # keep only the images labeled 0 or 1
x_test_01 = x_test[y_test <= 1]

In [7]:
#one-hot encoding
y_train_10 = np_utils.to_categorical(y_train, 10)
y_test_10 = np_utils.to_categorical(y_test, 10)

y_train_01 = y_train[y_train <= 1]
y_train_01 = np_utils.to_categorical(y_train_01, 2)

y_test_01 = y_test[y_test <= 1]
y_test_01 = np_utils.to_categorical(y_test_01, 2)
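
The filtering and one-hot steps can be checked on a toy label array. This NumPy sketch uses made-up labels and index-valued stand-in "images"; the point is that the same boolean mask `y <= 1` must be applied to both images and labels so they stay aligned. `np.eye(2)[labels]` is used here only as a stand-in that yields the same 0/1 matrix as `np_utils.to_categorical(labels, 2)`.

```python
import numpy as np

# Hypothetical labels standing in for y_train
y = np.array([5, 0, 4, 1, 9, 1, 0])
# Stand-in "images": each one is just its sample index
x = np.arange(len(y))

mask = y <= 1      # True where the label is 0 or 1
x_01 = x[mask]     # same indexing as x_train[y_train <= 1]
y_01 = y[mask]     # same indexing as y_train[y_train <= 1]

# One-hot encode the two remaining classes: row i of the 2x2 identity for label i
y_01_onehot = np.eye(2, dtype=np.float32)[y_01]

print(x_01)          # [1 3 5 6]
print(y_01_onehot)
```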

In [10]:
x_train_01.shape, x_test_01.shape

((12665, 28, 28, 1), (2115, 28, 28, 1))

In [9]:
y_train_01.shape, y_test_01.shape

((12665, 2), (2115, 2))

**2. Building the CNN**

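The classic CNN image-recognition model LeNet-5 consists of two convolutional layers followed by three fully connected layers. The network built here follows a similar, LeNet-5-style design:

- It starts with 3 convolutional blocks

    - Each convolutional block is one 2D convolution + ReLU + one 2D max pooling
    - The three 2D convolutions have 32, 64 and 128 filters
    - Every 2D convolution uses kernel_size 3 (i.e. (3, 3)) with padding 'same'
    - Every 2D max-pooling layer uses pool_size 2 (i.e. (2, 2)); the code below leaves its padding at the Keras default, 'valid'
- The output is flattened and fed into two fully connected layers with 200 and 10 neurons (10 being the number of digit classes)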
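
The layer shapes and parameter counts in this design can be worked out by hand. The plain-Python sketch below is not Keras code; the helper names `conv_params`, `pool_out` and `trace` are hypothetical. It assumes stride-1 3×3 convolutions with 'same' padding and 2×2 pooling with stride 2. Each convolution has k·k·C_in·filters weights plus one bias per filter, giving 320, 18496 and 73856 parameters, and the sketch shows that the Flatten length, and therefore the size of the first Dense layer, depends on the pooling padding: 1152 with 'valid', 2048 with 'same'.

```python
import math

def conv_params(kernel: int, in_ch: int, filters: int) -> int:
    # kernel weights (k * k * in_ch per filter) plus one bias per filter
    return kernel * kernel * in_ch * filters + filters

def pool_out(size: int, pool: int = 2, padding: str = "valid") -> int:
    # side length after max pooling with stride equal to pool_size
    if padding == "same":
        return math.ceil(size / pool)
    return (size - pool) // pool + 1

def trace(padding: str):
    side, in_ch, rows = 28, 1, []
    for filters in (32, 64, 128):
        # Conv2D with padding='same' and stride 1 keeps the spatial size
        rows.append((side, filters, conv_params(3, in_ch, filters)))
        side, in_ch = pool_out(side, padding=padding), filters
    # per-conv (side, filters, params) and the length of the Flatten output
    return rows, side * side * in_ch

for padding in ("valid", "same"):
    rows, flat = trace(padding)
    print(padding, rows, "Flatten:", flat, "Dense(200) params:", flat * 200 + 200)
```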

In [17]:
con_layer = [Conv2D(32, (3,3), padding = 'same', input_shape = (28,28,1)),
             Activation('relu'),
             MaxPooling2D(pool_size = (2,2)),
                    
             Conv2D(64, (3,3), padding = 'same'),
             Activation('relu'),
             MaxPooling2D(pool_size = (2,2)),
            
             Conv2D(128, (3,3), padding = 'same'),
             Activation('relu'),
             MaxPooling2D(pool_size = (2,2))]

fc_layer = [Flatten(),
            Dense(200),
            Activation('relu'),
            Dense(10),
            Activation('softmax')]

model = Sequential(con_layer + fc_layer)
model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_7 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
activation_11 (Activation)   (None, 28, 28, 32)        0         
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 64)        0         
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 128)        

In [18]:
model.load_weights('CNN_handwriting_weight.h5')

**3. Keeping the first three convolutional layers for transfer learning**

We take only the 0 and 1 images from the MNIST dataset and use transfer learning to build a LeNet-5-like model that tells 0 from 1.

In [20]:
new_fc_layer = [Flatten(),
            Dense(500),
            Activation('tanh'),
            Dense(2),
            Activation('softmax')]

model_0_1 = Sequential(con_layer + new_fc_layer)
model_0_1.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_7 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
activation_11 (Activation)   (None, 28, 28, 32)        0         
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 64)        0         
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 128)        

In [21]:
for layer in con_layer:
    layer.trainable = False
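
Because `model_0_1` reuses the very same layer objects in `con_layer`, it inherits the weights loaded into `model`, and setting `trainable = False` freezes those layers in both models. The flags are set before `compile`, which matters because Keras only applies changes to `trainable` when a model is compiled. The sketch below roughly counts what stays frozen and what remains trainable; the helper names are hypothetical, and the Flatten length of 3×3×128 = 1152 is an assumption based on 'valid' max pooling (the summaries shown here are truncated, so this is not read from them).

```python
def conv_params(kernel: int, in_ch: int, filters: int) -> int:
    # kernel weights plus one bias per filter
    return kernel * kernel * in_ch * filters + filters

def dense_params(n_in: int, n_out: int) -> int:
    # weight matrix plus one bias per output unit
    return n_in * n_out + n_out

# Frozen convolutional blocks (shared with the original 10-class model)
frozen = conv_params(3, 1, 32) + conv_params(3, 32, 64) + conv_params(3, 64, 128)

# Assumption: Flatten length 3 * 3 * 128 from 'valid' 2x2 pooling
flatten_len = 3 * 3 * 128
# New head: Dense(500) -> tanh -> Dense(2) -> softmax
trainable = dense_params(flatten_len, 500) + dense_params(500, 2)

print("non-trainable:", frozen)   # 92672
print("trainable:", trainable)    # 577502
```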

In [22]:
model_0_1.summary()

Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_7 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
activation_11 (Activation)   (None, 28, 28, 32)        0         
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 14, 14, 32)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
activation_12 (Activation)   (None, 14, 14, 64)        0         
_________________________________________________________________
max_pooling2d_8 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 7, 7, 128)        

In [23]:
# learning_rate is the current name of SGD's step size (lr is the deprecated alias)
model_0_1.compile(loss='mse', optimizer=SGD(learning_rate=0.1), metrics=['accuracy'])
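
The compile step pairs an MSE loss with the softmax output and plain SGD. The NumPy sketch below uses made-up logits and weights, not values from the model: Keras's 'mse' averages the squared error over the class axis for each sample, and the reported loss is the mean over the batch; SGD with its default momentum of 0 moves each trainable weight by -learning_rate × gradient.

```python
import numpy as np

def softmax(z):
    # subtract the row max for numerical stability
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def mse(y_true, y_pred):
    # squared error averaged over the class axis, then over the batch
    return np.mean(np.mean((y_true - y_pred) ** 2, axis=-1))

y_true = np.array([[1.0, 0.0], [0.0, 1.0]])   # one-hot labels for "0" and "1"
logits = np.array([[2.0, -1.0], [0.5, 0.5]])  # hypothetical pre-softmax outputs
y_pred = softmax(logits)
print("batch MSE:", mse(y_true, y_pred))

# One plain SGD step with learning rate 0.1 on hypothetical weights/gradients
learning_rate = 0.1
w = np.array([0.3, -0.2])
grad = np.array([0.5, -1.0])
w_new = w - learning_rate * grad
print("updated weights:", w_new)
```

A confident correct prediction (first row) contributes almost nothing to the loss, while a 50/50 guess (second row) contributes 0.25, which is why a very small test MSE goes hand in hand with high accuracy here.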

**4. Training the model**

In [25]:
model_0_1.fit(x_train_01, y_train_01, batch_size=100, epochs=5)


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.callbacks.History at 0x1e0c46fac88>
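
With 12,665 training images and `batch_size=100`, each epoch runs ceil(12665 / 100) = 127 mini-batch updates, the last one holding the remaining 65 samples, so 5 epochs amount to 635 weight updates in total. A quick check of that arithmetic:

```python
import math

n_samples = 12665   # x_train_01.shape[0]
batch_size = 100
epochs = 5

# number of mini-batches per epoch (the final batch is partial)
steps_per_epoch = math.ceil(n_samples / batch_size)
# size of that final, partial batch
last_batch = n_samples - (steps_per_epoch - 1) * batch_size
# total number of SGD updates over the whole run
total_updates = steps_per_epoch * epochs

print(steps_per_epoch, last_batch, total_updates)   # 127 65 635
```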

In [27]:
score = model_0_1.evaluate(x_test_01, y_test_01)



In [28]:
print('Test loss:', score[0])
print('Test accuracy:', score[1])

Test loss: 0.00028778230123823996
Test accuracy: 1.0
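
`evaluate` returns the loss followed by the compiled metrics, so `score[0]` is the test MSE and `score[1]` the accuracy. With one-hot targets and a two-unit softmax output, 'accuracy' is assumed here to resolve to categorical accuracy: a sample counts as correct when the arg-max of the prediction matches the arg-max of the label. A NumPy sketch on hypothetical predictions (the function name `categorical_accuracy` is illustrative, not a call into Keras):

```python
import numpy as np

def categorical_accuracy(y_true, y_pred):
    # fraction of samples whose highest-scoring class matches the one-hot label
    return np.mean(np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1))

y_true = np.array([[1, 0], [0, 1], [0, 1], [1, 0]], dtype=float)
# hypothetical softmax outputs; the third sample is misclassified
y_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]])

print(categorical_accuracy(y_true, y_pred))   # 0.75
```

Under that reading, an accuracy of 1.0 means every one of the 2,115 test images was assigned the correct class.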
