In [17]:
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import InverseTimeDecay
from tensorflow.keras.datasets import mnist
from sklearn.preprocessing import OneHotEncoder
import cv2

import numpy as np


# Load MNIST: grayscale digit images with integer class labels.
(X_train, y_train), (X_test, y_test) = mnist.load_data()


def _to_vgg_input(images, size=(48, 48)):
    """Convert a batch of grayscale images into VGG16-ready float32arrays.

    Each image is resized to ``size`` (OpenCV ``(width, height)`` order) and
    its single channel is replicated to RGB, because VGG16 expects three input
    channels and, with ``include_top=False``, a spatial size of at least 32x32.

    Args:
        images: iterable of 2-D grayscale images accepted by ``cv2.resize``.
        size: target ``(width, height)``; defaults to the 48x48 input the
            model below is built for.

    Returns:
        float32 array of shape ``(n, height, width, 3)`` with every value
        divided by 255 (i.e. scaled to [0, 1] for uint8 input).
    """
    batch = np.stack(
        [cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_GRAY2RGB) for img in images]
    ).astype('float32')
    # Divide in place to avoid allocating a second full-size float32 copy.
    batch /= 255
    return batch


# NOTE(review): the ImageNet VGG16 weights were trained on inputs prepared by
# `tensorflow.keras.applications.vgg16.preprocess_input` (BGR, mean-centred),
# not on [0, 1]-scaled pixels; consider switching if accuracy is poor.
X_train = _to_vgg_input(X_train)
X_test = _to_vgg_input(X_test)

# Fit the encoder's category list on the training labels only and reuse it for
# the test labels, so both one-hot matrices share the same column layout.
# `.toarray()` turns the sparse encoder output into dense NumPy arrays so that
# `model.fit` receives plain ndarrays.
one_hot_encoder = OneHotEncoder()
y_train_one_hot = one_hot_encoder.fit_transform(y_train.reshape(-1, 1)).toarray()
y_test_one_hot = one_hot_encoder.transform(y_test.reshape(-1, 1)).toarray()


# Build the VGG16 convolutional base initialised with ImageNet weights.
# include_top=False removes VGG16's original fully connected output layers.
model_vgg = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(48, 48, 3),
)

# Freeze the pretrained VGG weights: only the new head below is trained.
for layer in model_vgg.layers:
    layer.trainable = False

# Classifier head for the 10 digit classes, applied in order to the base output.
head_layers = [
    Flatten(name="flatten"),
    Dense(4096, activation='relu', name='fc1'),
    Dense(4096, activation='relu', name='fc2'),
    Dropout(0.5),
    Dense(10, activation="softmax"),
]
model = model_vgg.output
for head_layer in head_layers:
    model = head_layer(model)

# Wire the base's input to the head's output as one model and print its layers.
model_vgg_mnist = Model(inputs=model_vgg.input, outputs=model, name='vgg16')
model_vgg_mnist.summary()


# SGD with initial learning rate 0.05 and inverse-time decay.
# The legacy `lr=` / `decay=` keywords are deprecated, and `decay` is rejected
# by the optimizers in TF >= 2.11 / Keras 3. The same schedule,
# lr_t = 0.05 / (1 + 1e-5 * step), is expressed with InverseTimeDecay.
sgd = SGD(
    learning_rate=InverseTimeDecay(
        initial_learning_rate=0.05, decay_steps=1, decay_rate=1e-5
    )
)
model_vgg_mnist.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
# NOTE: the test split doubles as validation data here, so the reported
# validation accuracy is not an independent held-out estimate.
model_vgg_mnist.fit(
    X_train,
    y_train_one_hot,
    validation_data=(X_test, y_test_one_hot),
    epochs=20,
    batch_size=50,
)


_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         (None, 48, 48, 3)         0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 48, 48, 64)        1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 48, 48, 64)        36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 24, 24, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 24, 24, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 24, 24, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 12, 12, 128)       0         
__________

KeyboardInterrupt: 