Skip to content

Commit

Permalink
fix vgg finetuning example
Browse files Browse the repository at this point in the history
  • Loading branch information
aymericdamien committed Oct 12, 2016
1 parent d0e3699 commit 3d4bce7
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions examples/images/vgg_network_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
last layer (softmax) that will be retrained to match the new task (finetuning).
'''
import tflearn
from tflearn.layers.estimator import regression
from tflearn.data_preprocessing import ImagePreprocessing
import os


def vgg16(num_class, placeholderX=None):
x = tflearn.input_data(shape=[None, 224, 224, 3], name='input',
placeholder=placeholderX)
def vgg16(input, num_class):

x = tflearn.conv_2d(x, 64, 3, activation='relu', scope='conv1_1')
x = tflearn.conv_2d(input, 64, 3, activation='relu', scope='conv1_1')
x = tflearn.conv_2d(x, 64, 3, activation='relu', scope='conv1_2')
x = tflearn.max_pool_2d(x, 2, strides=2, name='maxpool1')

Expand Down Expand Up @@ -42,7 +40,8 @@ def vgg16(num_class, placeholderX=None):
x = tflearn.fully_connected(x, 4096, activation='relu', scope='fc7')
x = tflearn.dropout(x, 0.5, name='dropout2')

x = tflearn.fully_connected(x, num_class, activation='softmax', scope='fc8', restore=False)
x = tflearn.fully_connected(x, num_class, activation='softmax', scope='fc8',
restore=False)

return x

Expand All @@ -54,26 +53,38 @@ def vgg16(num_class, placeholderX=None):

from tflearn.data_utils import image_preloader

X, Y = image_preloader(files_list, image_shape=(224, 224), mode='file', categorical_labels=True, normalize=True,
X, Y = image_preloader(files_list, image_shape=(224, 224), mode='file',
categorical_labels=True, normalize=True,
files_extension=['.jpg', '.png'], filter_channel=True)
# or use the mode 'floder'
# X, Y = image_preloader(data_dir, image_shape=(224, 224), mode='folder', categorical_labels=True, normalize=True,
# X, Y = image_preloader(data_dir, image_shape=(224, 224), mode='folder',
# categorical_labels=True, normalize=True,
# files_extension=['.jpg', '.png'], filter_channel=True)

num_classes = 10 # num of your dataset

softmax = vgg16(num_classes)
regression = regression(softmax, optimizer='adam',
loss='categorical_crossentropy',
learning_rate=0.001, restore=False)
# VGG preprocessing
img_prep = ImagePreprocessing()
img_prep.add_featurewise_zero_center(mean=[123.68, 116.779, 103.939],
per_channel=True)
# VGG Network
x = tflearn.input_data(shape=[None, 224, 224, 3], name='input',
data_preprocessing=img_prep)
softmax = vgg16(x, num_classes)
regression = tflearn.regression(softmax, optimizer='adam',
loss='categorical_crossentropy',
learning_rate=0.001, restore=False)

model = tflearn.DNN(regression, checkpoint_path='vgg-finetuning',
max_checkpoints=3, tensorboard_verbose=2, tensorboard_dir="./logs")
max_checkpoints=3, tensorboard_verbose=2,
tensorboard_dir="./logs")

model_file = os.path.join(model_path, "vgg16.tflearn")
model.load(model_file, weights_only=True)

# Start finetuning
model.fit(X, Y, n_epoch=10, validation_set=0.1, shuffle=True,
show_metric=True, batch_size=64, snapshot_epoch=False, snapshot_step=200, run_id='vgg-finetuning')
show_metric=True, batch_size=64, snapshot_epoch=False,
snapshot_step=200, run_id='vgg-finetuning')

model.save('your-task-model-retrained-by-vgg')

0 comments on commit 3d4bce7

Please sign in to comment.