In [None]:
# 예제 프로그램 15.1-1

import cv2
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.applications.inception_v3 import InceptionV3

# Target input size for InceptionV3. Per the Keras InceptionV3 docs, inputs
# must be at least 75x75 when include_top=False, so the MNIST digits are
# upscaled to 84x84 further down.
width = 84
height = 84
channel = 3  # number of input channels the InceptionV3 model is built with

# Load the MNIST digit images and their labels.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels to [0, 1]. Casting to float32 *before* dividing avoids the
# float64 arrays that a plain `x / 255` produces, halving memory; the
# Inception input arrays built later are float32 anyway.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0

def _to_inception_input(images):
    """Convert a batch of single-channel images into InceptionV3 input.

    Each image is resized to (height, width) with bilinear interpolation,
    replicated across `channel` channels, and finally rescaled from [0, 1]
    to [-1, 1], the input range the Keras docs specify for InceptionV3
    (what `keras.applications.inception_v3.preprocess_input` produces).

    Args:
        images: sequence of 2-D arrays; pixel values are assumed to already
            be normalized to [0, 1] -- confirm if reused elsewhere.

    Returns:
        float32 array of shape (len(images), height, width, channel).
    """
    # NOTE(review): for a full MNIST-sized split this array is several GB;
    # a streaming tf.data pipeline would be needed if memory is tight.
    out = np.empty((len(images), height, width, channel), dtype=np.float32)
    for i, img in enumerate(images):
        # cv2 takes dsize as (width, height) and returns a (height, width) array.
        resized = cv2.resize(img, dsize=(width, height),
                             interpolation=cv2.INTER_LINEAR)
        # Broadcasting the trailing singleton axis copies the gray values into
        # every channel (same as cv2.merge((g, g, g)) when channel == 3).
        out[i] = resized[..., np.newaxis]
    # [0, 1] -> [-1, 1], done in place to avoid a second full-size copy.
    out *= 2.0
    out -= 1.0
    return out


# Reshape the training and test images to the InceptionV3 input format.
x_train_tuned = _to_inception_input(x_train)
x_test_tuned = _to_inception_input(x_test)

# Load InceptionV3 with ImageNet weights and without its classification head.
# The backbone is taken from the same `keras` namespace (tensorflow.keras) as
# the layers and Model used below, instead of the standalone `keras` package,
# so every graph object comes from a single Keras implementation.
base_model = keras.applications.InceptionV3(
    input_shape=(height, width, channel),
    include_top=False,
    weights='imagenet',
)
base_model.summary()
print('\n\n')

# Reuse the pretrained weights as-is: freeze every backbone layer so that
# only the new head defined below is trained.
for layer in base_model.layers:
    layer.trainable = False

# Cut the backbone at the 'mixed8' block; layers after it are not part of
# the model assembled below.
last_layer = base_model.get_layer('mixed8')
last_output = last_layer.output

# New classification head: flatten -> 256 ReLU units -> 10-way softmax.
x = keras.layers.Flatten()(last_output)
x = keras.layers.Dense(256, activation='relu')(x)
x = keras.layers.Dense(10, activation='softmax')(x)

# Full model: backbone input -> mixed8 features -> new head.
model = keras.models.Model(base_model.input, x)

# Sparse categorical cross-entropy takes integer class labels directly.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
print('\n\n')

# Number of training epochs; also used as the plot's x-axis extent so the
# two values cannot drift apart.
epochs = 10

# Train the model, evaluating on the test split after each epoch.
history = model.fit(x_train_tuned, y_train, epochs=epochs,
                    validation_data=(x_test_tuned, y_test))

# Plot every per-epoch metric recorded by fit (one line per history key).
ax = pd.DataFrame(history.history).plot.line(figsize=(5, 3))
ax.set_xlim(0, epochs)
ax.set_ylim(0, 1)  # NOTE: any loss value above 1 is clipped from view
ax.set_xlabel('epoch')
plt.title('\nLearning results')
plt.show()
