In [1]:
######################################## Library for CNN ########################################
from tensorflow.keras import Model, models
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import *
from tensorflow.keras.applications import vgg16, resnet50, mobilenet, densenet, xception
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.utils import plot_model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow import keras


######################################## General Library ########################################
import os
import copy
import time
import json as js
import numpy as np


######################################### Plot Library #########################################
from bokeh.plotting import figure, output_file, show
from bokeh.io import export_png
from bokeh.models import ColumnDataSource
from bokeh.palettes import cividis
from bokeh.layouts import gridplot

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# ---------------------------------------------------------------------------
# Experiment configuration: everything a reader might want to tune.
# ---------------------------------------------------------------------------
BATCH_SIZE = 32        # images per batch handed to the directory generators
EPOCH = 30             # upper bound on training epochs
IMAGE_SIZE = 100       # images are resized to IMAGE_SIZE x IMAGE_SIZE
LEARNING_RATE = 1e-4   # initial Adam learning rate
PATIENCE = 10          # early-stopping patience, in epochs
CUT_AT = 42            # index of the backbone layer whose output feeds the head (-1 = backbone output)
UNFREEZE_FROM = 0      # layers from this index onward are made trainable
BASE_ARCH = 'Xception'

# Dataset folders are keyed by the image size they were cropped to.
DATASET_ROOT = 'japanese_fagaceae_dataset_cropped_' + str(IMAGE_SIZE)
TRAINING_DATASET = os.path.join(DATASET_ROOT, 'train')
VALIDATION_DATASET = os.path.join(DATASET_ROOT, 'validation')

# One class per sub-directory of the training set. Plain files (e.g. a stray
# .DS_Store or Thumbs.db) are skipped so they cannot inflate the class count
# used for the size of the softmax layer.
CLASSES = sum(
    1
    for entry in os.listdir(TRAINING_DATASET)
    if os.path.isdir(os.path.join(TRAINING_DATASET, entry))
)

# Path components of the folder that receives weights, plots and label map.
OUTPUT_PATH = ['result', str(IMAGE_SIZE), BASE_ARCH]

# Later cells write files into this folder; make sure it exists up front.
os.makedirs(os.path.join(*OUTPUT_PATH), exist_ok=True)

In [3]:
# Each supported backbone ships its own `preprocess_input`; look it up by name.
_PREPROCESSORS = {
    'VGG16': vgg16.preprocess_input,
    'ResNet50': resnet50.preprocess_input,
    'MobileNet': mobilenet.preprocess_input,
    'DenseNet121': densenet.preprocess_input,
    'Xception': xception.preprocess_input,
}

# Fail immediately on an unknown architecture instead of leaving
# `preprocessing_function` undefined and hitting a NameError further down.
if BASE_ARCH not in _PREPROCESSORS:
    raise ValueError(
        'Unsupported BASE_ARCH %r; expected one of %s'
        % (BASE_ARCH, sorted(_PREPROCESSORS))
    )
preprocessing_function = _PREPROCESSORS[BASE_ARCH]

# Train and validation images go through the same backbone-specific
# preprocessing; no augmentation is configured on either generator.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocessing_function
)
validation_datagen = ImageDataGenerator(
    preprocessing_function=preprocessing_function
)

In [4]:
# Options shared by both directory iterators: square resize, fixed batch
# size, one-hot ('categorical') labels.
_flow_options = dict(
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Training batches: shuffling is left at the library default.
train_generator = train_datagen.flow_from_directory(
    TRAINING_DATASET, **_flow_options
)

# Validation batches: shuffling is disabled so samples are served in a
# fixed order.
validation_generator = validation_datagen.flow_from_directory(
    VALIDATION_DATASET, shuffle=False, **_flow_options
)

Found 147600 images belonging to 9 classes.
Found 22140 images belonging to 9 classes.


In [5]:
# Path components of the checkpoint file: <OUTPUT_PATH>/[ARCH_SIZE] weight.h5
MODEL_WEIGHT = copy.copy(OUTPUT_PATH)
MODEL_WEIGHT.append('['+BASE_ARCH+'_'+str(IMAGE_SIZE)+'] weight.h5')

# All three callbacks watch validation accuracy, where higher is better.
# NOTE(review): 'val_acc' must be present in the training logs; the history
# plot in this notebook reads the same 'acc' / 'val_acc' keys.

# Keep only the best-scoring weights on disk.
Checkpoint = ModelCheckpoint(
    os.path.sep.join(MODEL_WEIGHT),
    monitor='val_acc',
    save_best_only=True,
    mode="max",  # explicit, matching the other two callbacks
)

# Stop once validation accuracy has not improved for PATIENCE epochs.
# The variable deliberately keeps the name `EarlyStopping` because the
# training cell references it, but the class is reached through
# `keras.callbacks` so this cell can be re-run: the original
# `EarlyStopping = EarlyStopping(...)` replaced the class with an instance
# and failed on the second execution.
EarlyStopping = keras.callbacks.EarlyStopping(
    monitor="val_acc",
    patience=PATIENCE,
    mode="max",
)

# Multiply the learning rate by 0.6 after 5 epochs without improvement,
# never going below 1e-6.
ReduceLR = ReduceLROnPlateau(
    monitor="val_acc",
    factor=0.6,
    patience=5,
    min_lr=1e-6,
    verbose=1,
    mode="max",
)

In [6]:
# ---------------------------------------------------------------------------
# Build the classifier: ImageNet-pretrained backbone (optionally truncated)
# followed by a small dense head.
# ---------------------------------------------------------------------------

# Backbone constructors keyed by architecture name.
_BACKBONES = {
    'VGG16': vgg16.VGG16,
    'ResNet50': resnet50.ResNet50,
    'MobileNet': mobilenet.MobileNet,
    'DenseNet121': densenet.DenseNet121,
    'Xception': xception.Xception,
}

# Fail immediately on an unknown architecture instead of leaving `model`
# undefined and hitting a NameError on the next statement.
if BASE_ARCH not in _BACKBONES:
    raise ValueError(
        'Unsupported BASE_ARCH %r; expected one of %s'
        % (BASE_ARCH, sorted(_BACKBONES))
    )

# load model (convolutional part only, ImageNet weights)
model = _BACKBONES[BASE_ARCH](
    weights='imagenet',
    include_top=False,
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
)

# Start from a fully frozen backbone; selected layers are re-enabled below.
for layer in model.layers:
    layer.trainable = False

# Choose the feature map that feeds the head: either the backbone's final
# output (CUT_AT == -1) or the output of the layer at index CUT_AT.
if CUT_AT == -1:
    x = model.output
else:
    x = model.layers[CUT_AT].output

# Classification head.
x = GlobalAveragePooling2D()(x)
x = Dense(512, activation='relu')(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)
predictions = Dense(CLASSES, activation='softmax')(x)

# instantiate new model
jf_model = Model(inputs=model.input, outputs=predictions, name='japanese_fagaceae_model')

# unfreeze selected layer: everything from index UNFREEZE_FROM onward
# (UNFREEZE_FROM == 0 makes the whole network trainable).
for layer in jf_model.layers[UNFREEZE_FROM:]:
    layer.trainable = True

# optimizer
opt = Adam(learning_rate=LEARNING_RATE)

# compile model
# The metric is requested as 'acc' so the logged keys are 'acc' / 'val_acc',
# which is what the callbacks and the history plot in this notebook look up.
# With 'accuracy', newer TensorFlow releases log 'accuracy' / 'val_accuracy'
# and those lookups would miss.
jf_model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['acc']
)

Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


In [7]:
# List every layer with its position and whether it will be trained.
for position, current_layer in enumerate(jf_model.layers):
    print(position, current_layer, current_layer.trainable)

0 <tensorflow.python.keras.engine.input_layer.InputLayer object at 0x000001B875F6D048> True
1 <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x000001B870E75388> True
2 <tensorflow.python.keras.layers.normalization.BatchNormalization object at 0x000001B8702FD048> True
3 <tensorflow.python.keras.layers.core.Activation object at 0x000001B875FAF9C8> True
4 <tensorflow.python.keras.layers.convolutional.Conv2D object at 0x000001B875F79748> True
5 <tensorflow.python.keras.layers.normalization.BatchNormalization object at 0x000001B87609CFC8> True
6 <tensorflow.python.keras.layers.core.Activation object at 0x000001B8760C5F88> True
7 <tensorflow.python.keras.layers.convolutional.SeparableConv2D object at 0x000001B8761A8408> True
8 <tensorflow.python.keras.layers.normalization.BatchNormalization object at 0x000001B876264E08> True
9 <tensorflow.python.keras.layers.core.Activation object at 0x000001B876281348> True
10 <tensorflow.python.keras.layers.convolutional.SeparableConv2D obj

In [8]:
# Path components of the architecture diagram:
# <OUTPUT_PATH>/[ARCH_SIZE] model_arch.png
PLOT_ARCH = copy.copy(OUTPUT_PATH)
PLOT_ARCH.append('[' + BASE_ARCH + '_' + str(IMAGE_SIZE) + '] model_arch.png')

# Render the model graph to a PNG, showing tensor shapes but not layer names.
arch_png_path = os.path.sep.join(PLOT_ARCH)
plot_model(jf_model, to_file=arch_png_path, show_shapes=True, show_layer_names=False)

print('[SAVED]', PLOT_ARCH)

[SAVED] ['result', '100', 'Xception', '[Xception_100] model_arch.png']


In [9]:
# Train the model and report how long the whole run took (wall clock).
# NOTE(review): `fit_generator` is deprecated in newer TensorFlow releases
# in favour of `fit` - confirm against the installed version.
training_callbacks = [Checkpoint, EarlyStopping, ReduceLR]

tic = time.time()
history = jf_model.fit_generator(
    train_generator,
    epochs=EPOCH,
    validation_data=validation_generator,
    callbacks=training_callbacks,
    verbose=1,
)
toc = time.time()

# Same duration printed twice: first in seconds, then in minutes.
elapsed_seconds = toc - tic
print('\n\ntraining speed = %.2f seconds' % elapsed_seconds)
print('training speed = %.2f minutes' % (elapsed_seconds / 60))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 00017: ReduceLROnPlateau reducing learning rate to 5.999999848427251e-05.
Epoch 18/20
Epoch 19/20
Epoch 20/20


training speed = 8421.03 seconds
training speed = 140.35 minutes


In [10]:
# Path components of the label-map file: <OUTPUT_PATH>/[ARCH_SIZE] weight.txt
MODEL_LABEL = copy.copy(OUTPUT_PATH)
MODEL_LABEL.append('[' + BASE_ARCH + '_' + str(IMAGE_SIZE) + '] weight.txt')

# Invert the generator's {class name: index} mapping into
# {index: [class name]} (each name wrapped in a one-element list) and
# store it as JSON.
label2index = train_generator.class_indices
idx2label = {index: [name] for name, index in label2index.items()}

label_file_path = os.path.sep.join(MODEL_LABEL)
with open(label_file_path, 'w') as label_file:
    js.dump(idx2label, label_file)

print('[SAVED]', MODEL_LABEL)

[SAVED] ['result', '100', 'Xception', '[Xception_100] weight.txt']


In [11]:
# Path components of the interactive training-history page:
# <OUTPUT_PATH>/[ARCH_SIZE] training_history.html
PLOT_TRAIN_HISTORY = copy.copy(OUTPUT_PATH)
PLOT_TRAIN_HISTORY.append('[' + BASE_ARCH + '_' + str(IMAGE_SIZE) + '] training_history.html')
output_file(os.path.sep.join(PLOT_TRAIN_HISTORY))

# Four colours: train line, train markers, validation line, validation markers.
palette = cividis(4)


def _history_source(train_key, validation_key):
    """Wrap one train/validation pair of history curves for bokeh.

    The x column is the 1-based epoch number; its length is taken from the
    'acc' curve for both sources.
    """
    epochs = list(np.arange(1, len(history.history['acc']) + 1))
    return ColumnDataSource(dict(
        x=epochs,
        trainY=list(history.history[train_key]),
        validationY=list(history.history[validation_key]),
    ))


def _history_figure(title, y_label, train_label, validation_label,
                    source, legend_location=None):
    """Build one styled figure with train and validation curves.

    Each curve is drawn as a line plus circle markers sharing one legend
    entry. `legend_location` is applied only when given.
    """
    plot = figure(title=title,
                  x_axis_label='Epoch',
                  y_axis_label=y_label)
    plot.axis.axis_label_text_font_size = '25pt'
    plot.axis.major_label_text_font_size = '15pt'
    plot.title.text_font_size = '25pt'

    # (data column, legend text, line colour, marker colour)
    curves = (
        ('trainY', train_label, palette[0], palette[1]),
        ('validationY', validation_label, palette[2], palette[3]),
    )
    for column, label, line_color, marker_color in curves:
        plot.line(x='x', y=column,
                  color=line_color,
                  legend_label=label,
                  source=source)
        plot.circle(x='x', y=column,
                    color=marker_color,
                    legend_label=label,
                    source=source)

    # The legend only exists once glyphs with a legend label were added.
    if legend_location is not None:
        plot.legend.location = legend_location
    plot.legend.label_text_font_size = "15pt"
    plot.legend.glyph_height = 30
    plot.legend.glyph_width = 30
    return plot


acc_source = _history_source('acc', 'val_acc')
loss_source = _history_source('loss', 'val_loss')

# Only the accuracy legend is moved to the top-left corner; the loss legend
# keeps its default position.
plot_acc = _history_figure('Model Accuracy', 'Accuracy',
                           'Train Accuracy', 'Validation Accuracy',
                           acc_source, legend_location="top_left")
plot_loss = _history_figure('Model Loss', 'Loss',
                            'Train Loss', 'Validation Loss',
                            loss_source)

# Accuracy and loss side by side.
grid = gridplot([plot_acc, plot_loss], ncols=2)

show(grid)

print('[SAVED]', PLOT_TRAIN_HISTORY)

[SAVED] ['result', '100', 'Xception', '[Xception_100] training_history.html']
