Permalink
Switch branches/tags
Nothing to show
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
98 lines (80 sloc) 3.03 KB
import numpy as np
from keras import applications
from keras import optimizers
from keras.layers import Dropout, Flatten, Dense, Input
from keras.models import Sequential, Model
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import plot_model
from keras.utils import to_categorical
# Weight files: the classifier-head weights that are loaded before
# fine-tuning, and the file the fine-tuned model's weights are saved to.
top_model_weights_path = 'fc_model.h5'
full_model_weights_path = 'full_model.h5'

# Width and height every image is resized to before entering the network.
img_width = img_height = 299

# Root directories of the training and validation images.
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'

# Sample counts used to size the training and validation runs.
nb_train_samples = 1808
nb_validation_samples = 192

# Training schedule.
epochs = 50
batch_size = 16
# Convolutional base: InceptionV3 with ImageNet weights and without its
# fully-connected classification top, fed from an explicit input tensor
# of shape (height, width, 3).
input_shape = (img_height, img_width, 3)
input_tensor = Input(shape=input_shape)
base_model = applications.InceptionV3(
    include_top=False,
    weights='imagenet',
    input_tensor=input_tensor,
)
print('Model loaded.')
# Classifier head that sits on top of the convolutional base: flatten the
# base's output (batch dimension dropped from output_shape), one hidden
# ReLU layer with dropout, then a 5-way softmax.
head_layers = [
    Flatten(input_shape=base_model.output_shape[1:]),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(5, activation='softmax'),
]
top_model = Sequential(head_layers)

# Fine-tuning only works when it starts from a fully-trained classifier,
# so the head's previously saved weights are loaded before it is attached.
top_model.load_weights(top_model_weights_path)
# Stack the classifier head on the convolutional base to obtain a single
# end-to-end model.
model = Model(inputs=base_model.input, outputs=top_model(base_model.output))

# Freeze every layer except the last five of the convolutional base (and
# the classifier head): frozen layers keep their weights during training.
# The slice is taken over model.layers, bounded by the number of layers in
# the base model.
nb_frozen_layers = len(base_model.layers) - 5
for layer in model.layers[:nb_frozen_layers]:
    layer.trainable = False
# Compile for fine-tuning: SGD with momentum and a deliberately small
# learning rate, so the pre-trained weights are only adjusted gently.
sgd = optimizers.SGD(lr=1e-4, momentum=0.9)
model.compile(
    optimizer=sgd,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Training images are multiplied by 1/255 and randomly augmented
# (rotation, shifts, shear, zoom, horizontal flip).
# NOTE(review): InceptionV3's own preprocess_input scales differently;
# confirm 1/255 matches how the weights in fc_model.h5 were produced.
augmentation_options = dict(
    rescale=1. / 255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_datagen = ImageDataGenerator(**augmentation_options)

# Validation images are only rescaled, never augmented.
test_datagen = ImageDataGenerator(rescale=1. / 255)
# Both generators read images from their directory with identical
# settings: resized to the network input size, grouped into batches, and
# labelled with categorical class vectors.
flow_options = {
    'target_size': (img_height, img_width),
    'batch_size': batch_size,
    'class_mode': 'categorical',
}
train_generator = train_datagen.flow_from_directory(
    train_data_dir, **flow_options)
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir, **flow_options)
# Fine-tune the model.
# fit_generator counts steps in batches, not samples, so one pass over the
# data takes ceil(samples / batch_size) steps. Passing the raw sample
# counts would make every epoch (and every validation run) batch_size
# times longer than intended.
train_steps = (nb_train_samples + batch_size - 1) // batch_size
validation_steps = (nb_validation_samples + batch_size - 1) // batch_size
model.fit_generator(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps)
# Debug aid (disabled): predict on a single validation batch.
# print(model.predict_generator(validation_generator, steps=1))

# Persist the fine-tuned weights first, so they are already on disk if the
# plotting step below fails.
model.save_weights(full_model_weights_path)

# Write a diagram of the model architecture to model.png. plot_model is
# imported with the other keras imports at the top of the file.
plot_model(model, to_file='model.png')
# from IPython.display import SVG
# from keras.utils.vis_utils import model_to_dot
#
# SVG(model_to_dot(model).create(prog='dot', format='svg'))