In [1]:
from __future__ import division
import random
import pprint
import sys
import time
import numpy as np
from optparse import OptionParser
import pickle
import re

from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from keras_frcnn import config, data_generators
from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from tensorflow.python.keras.utils import generic_utils

# Raise the interpreter's recursion limit to 40000.
# NOTE(review): presumably required by deep recursion somewhere in model
# construction or pickling -- confirm it is still needed.
sys.setrecursionlimit(40000)

Using TensorFlow backend.


In [2]:
from keras_frcnn.simple_parser import get_data

In [3]:
# Build the training configuration. The values are hard-coded in this cell
# (nothing is read from the command line) and stored on the config object so
# that they can be pickled and reloaded at test time.
C = config.Config()

# The flip and rotation flags are all disabled for this run.
C.use_horizontal_flips = False
C.use_vertical_flips = False
C.rot_90 = False

# Checkpoints are written as '<stem>_<epoch><ext>' by the training loop, so
# the path is split into its stem and '.hdf5' extension once, up front.
C.model_path = './model_frcnn.hdf5'
# Raw string: '\.' inside a normal string literal is an invalid escape sequence.
model_path_regex = re.match(r"^(.+)(\.hdf5)$", C.model_path)
if model_path_regex is None:
    # Fail now rather than with an AttributeError on `.group()` after a full epoch.
    raise ValueError(f'Output weights must have .hdf5 filetype, got {C.model_path!r}')

# Number of ROIs passed to the detector head per training step.
C.num_rois = 32

from keras_frcnn import resnet as nn
C.network = 'resnet50'

C.base_net_weights = nn.get_weight_path()

In [4]:
# Parse the training annotation file. The parser returns the image records,
# the per-class counts and the class-name -> index mapping.
train_imgs, classes_count, class_mapping = get_data('train_annotate.txt')

Parsing annotation files


In [5]:
# Inspect the class-name -> index mapping returned by the parser.
class_mapping

{'WBC': 0, 'RBC': 1, 'Platelets': 2}

In [6]:
# Parse the validation annotation file; only the image records are kept, the
# class counts and class mapping from the validation split are discarded.
val_imgs, _, _ = get_data('val_annotate.txt')

Parsing annotation files


In [7]:
# Register the background class if the annotations did not contain it; it
# receives the next free index in the class mapping and a count of zero.
bg_missing = 'bg' not in classes_count
if bg_missing:
    class_mapping['bg'] = len(class_mapping)
    classes_count['bg'] = 0

# Store the (possibly extended) mapping on the config object.
C.class_mapping = class_mapping

# Reverse lookup: class index -> class name.
inv_map = dict(zip(class_mapping.values(), class_mapping.keys()))

print('Training images per class:')
pprint.pprint(classes_count)
num_classes_with_bg = len(classes_count)
print(f'Num classes (including bg) = {len(classes_count)}')

Training images per class:
{'Platelets': 216, 'RBC': 2564, 'WBC': 236, 'bg': 0}
Num classes (including bg) = 4


In [8]:
# File the pickled config object is written to.
config_output_filename = 'config.pickle'

In [9]:
# Pickle the config object so the same settings can be reloaded at test time.
with open(config_output_filename, 'wb') as config_f:
    pickle.dump(C, config_f)

# Reached only if the dump above succeeded.
print(
    f'Config has been written to {config_output_filename}, '
    f'and can be loaded when testing to ensure correct results'
)

Config has been written to config.pickle, and can be loaded when testing to ensure correct results


In [10]:
# Shuffle the training records in place.
# NOTE(review): no random seed is set in the visible cells, so the order (and
# therefore the training run) differs between executions.
random.shuffle(train_imgs)

num_imgs = len(train_imgs)

In [11]:

# Report how many image records are in each split.
print(f'Num train samples {len(train_imgs)}')
print(f'Num val samples {len(val_imgs)}')

Num train samples 232
Num val samples 59


In [12]:
# Arguments shared by the training and the validation generator.
_anchor_gt_common = (classes_count, C, nn.get_img_output_length, K.image_data_format())

# Iterators over the training / validation records; the training loop pulls
# (X, Y, img_data) triples from the training one with next().
data_gen_train = data_generators.get_anchor_gt(train_imgs, *_anchor_gt_common, mode='train')
data_gen_val = data_generators.get_anchor_gt(val_imgs, *_anchor_gt_common, mode='val')

In [13]:

# Image input shape with unconstrained height/width; the position of the
# 3-channel axis follows the backend's image data format.
input_shape_img = (
    (3, None, None)
    if K.image_data_format() == 'channels_first'
    else (None, None, 3)
)

In [14]:
# Inspect the image input shape chosen for the active data format.
input_shape_img

(None, None, 3)

In [15]:
# Symbolic inputs: the image, and a variable-length set of ROIs with four
# values per ROI.
img_input = Input(shape=input_shape_img)
roi_input = Input(shape=(None, 4))

In [16]:
# Base network (`nn` is keras_frcnn.resnet) applied to the image input, built
# with trainable=True; its output feeds both the RPN and the classifier head.
shared_layers = nn.nn_base(img_input, trainable=True)

In [17]:
# Inspect the symbolic output tensor of the base network.
shared_layers

<KerasTensor: shape=(None, None, None, 1024) dtype=float32 (created by layer 'activation_39')>

In [18]:

# Define the RPN on top of the base layers. The anchor count is the number of
# configured anchor scales times the number of configured aspect ratios.
num_anchors = len(C.anchor_box_scales) * len(C.anchor_box_ratios)
rpn = nn.rpn(shared_layers, num_anchors)

In [19]:
# Inspect the list of tensors returned by nn.rpn.
rpn

[<KerasTensor: shape=(None, None, None, 9) dtype=float32 (created by layer 'rpn_out_class')>,
 <KerasTensor: shape=(None, None, None, 36) dtype=float32 (created by layer 'rpn_out_regress')>,
 <KerasTensor: shape=(None, None, None, 1024) dtype=float32 (created by layer 'activation_39')>]

In [20]:
# Detector head over the shared features and the ROI input. nb_classes is the
# size of classes_count, which includes the 'bg' entry added earlier.
classifier = nn.classifier(shared_layers, roi_input, C.num_rois,
                           nb_classes=len(classes_count), trainable=True)

In [21]:
# The RPN model exposes only the first two RPN tensors (rpn[:2]); the
# classifier model takes the image plus the ROIs as inputs.
model_rpn = Model(img_input, rpn[:2])
model_classifier = Model([img_input, roi_input], classifier)

In [22]:
# Combined model holding both the RPN and the classifier outputs. It exists so
# that the weights of both sub-models can be saved/loaded as a single file.
model_all = Model([img_input, roi_input], rpn[:2] + classifier)

In [23]:
# One Adam optimizer per model so the RPN and the detector head keep separate
# optimizer state. `learning_rate` replaces the deprecated `lr` keyword.
optimizer = Adam(learning_rate=1e-5)
optimizer_classifier = Adam(learning_rate=1e-5)

# RPN: one loss per output (objectness classification, box regression).
model_rpn.compile(optimizer=optimizer,
                  loss=[losses.rpn_loss_cls(num_anchors), losses.rpn_loss_regr(num_anchors)])

# Detector head: the regression loss is built for the class count minus one
# (classes_count includes 'bg').
# NOTE(review): the metrics key must match the name of the classification
# output layer created by nn.classifier -- confirm against that module.
model_classifier.compile(optimizer=optimizer_classifier, loss=[losses.class_loss_cls,
                                                               losses.class_loss_regr(
                                                                   len(classes_count) - 1)],
                         metrics={f'dense_class_{len(classes_count)}': 'accuracy'})

# model_all is only used to save/load weights in the visible cells and is never
# trained directly, so its optimizer and loss are placeholders.
model_all.compile(optimizer='sgd', loss='mae')

  "The `lr` argument is deprecated, use `learning_rate` instead.")


In [24]:
# Training schedule: iterations per epoch and number of epochs.
epoch_length = 1000
num_epochs = 2

In [25]:
# Mutable state consumed by the training loop.
iter_num = 0

# One row per iteration of the current epoch. Columns, as filled by the loop:
# 0 = RPN classification loss, 1 = RPN regression loss,
# 2 = detector classification loss, 3 = detector regression loss,
# 4 = detector classification accuracy.
# NOTE(review): this rebinds `losses`, hiding the `keras_frcnn.losses` module
# imported at the top of the notebook. Re-running the compile cell after this
# cell will fail; renaming needs a matching change in the training loop.
losses = np.zeros((epoch_length, 5))
# Positive-ROI counts per iteration: a rolling window for the periodic report,
# and the per-epoch list that is reset at the end of every epoch.
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []

In [26]:
# Alternating training loop. Every iteration:
#   1. trains the RPN on one batch from the training generator,
#   2. converts the RPN predictions into ROIs and matches them to ground truth,
#   3. samples C.num_rois of those ROIs and trains the detector head on them.
# After `epoch_length` iterations the mean losses are reported and the combined
# model's weights are written to an epoch-numbered file.
start_time = time.time()

# `np.Inf` was removed in NumPy 2.0; `np.inf` works on every NumPy version.
best_loss = np.inf

class_mapping_inv = {v: k for k, v in class_mapping.items()}
print('Starting training')

vis = True

for epoch_num in range(num_epochs):

    progbar = generic_utils.Progbar(epoch_length)
    print(f'Epoch {epoch_num + 1}/{num_epochs}')

    while True:
        # Periodic RPN health report once the rolling window holds
        # `epoch_length` entries; the window is cleared afterwards.
        if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
            mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(
                rpn_accuracy_rpn_monitor)
            rpn_accuracy_rpn_monitor = []
            print(
                f'Average number of overlapping bounding boxes '
                f'from RPN = {mean_overlapping_bboxes} for {epoch_length} previous iterations')
            if mean_overlapping_bboxes == 0:
                print('RPN is not producing bounding boxes that overlap the ground truth boxes.'
                      ' Check RPN settings or keep training.')

        X, Y, img_data = next(data_gen_train)

        loss_rpn = model_rpn.train_on_batch(X, Y)

        P_rpn = model_rpn.predict_on_batch(X)

        R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_data_format(),
                                   use_regr=True, overlap_thresh=0.7, max_boxes=300)
        # note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
        X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)

        # No usable ROI for this image: record zero and draw the next image
        # without counting this as a training iteration.
        if X2 is None:
            rpn_accuracy_rpn_monitor.append(0)
            rpn_accuracy_for_epoch.append(0)
            continue

        # The last column of Y1 flags negative ROIs (== 1) versus positive ones
        # (== 0). np.where on a 1-D condition always returns a 1-tuple, so the
        # index array can be taken directly; it is simply empty when nothing
        # matches.
        neg_samples = np.where(Y1[0, :, -1] == 1)[0]
        pos_samples = np.where(Y1[0, :, -1] == 0)[0]

        rpn_accuracy_rpn_monitor.append(len(pos_samples))
        rpn_accuracy_for_epoch.append(len(pos_samples))

        if C.num_rois > 1:
            # At most half of the minibatch is positive; the rest is negative.
            if len(pos_samples) < C.num_rois // 2:
                selected_pos_samples = pos_samples.tolist()
            else:
                selected_pos_samples = np.random.choice(pos_samples, C.num_rois // 2,
                                                        replace=False).tolist()
            try:
                selected_neg_samples = np.random.choice(neg_samples,
                                                        C.num_rois - len(selected_pos_samples),
                                                        replace=False).tolist()
            except ValueError:
                # Fewer negatives than requested: sample with replacement.
                # Only ValueError is caught so that KeyboardInterrupt and
                # unrelated errors are not swallowed. If there are no negatives
                # at all this second call raises ValueError as well.
                selected_neg_samples = np.random.choice(neg_samples,
                                                        C.num_rois - len(selected_pos_samples),
                                                        replace=True).tolist()

            sel_samples = selected_pos_samples + selected_neg_samples
        else:
            # in the extreme case where num_rois = 1, we pick a random pos or neg sample
            selected_pos_samples = pos_samples.tolist()
            selected_neg_samples = neg_samples.tolist()
            # Keep the pick inside a one-element list: indexing with a scalar
            # would drop the ROI axis from X2/Y1/Y2 below.
            if np.random.randint(0, 2):
                sel_samples = [random.choice(selected_neg_samples)]
            else:
                sel_samples = [random.choice(selected_pos_samples)]

        loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]],
                                                     [Y1[:, sel_samples, :],
                                                      Y2[:, sel_samples, :]])

        # Index 0 of each train_on_batch result (the total) is skipped; the
        # per-output values are stored.
        losses[iter_num, 0] = loss_rpn[1]
        losses[iter_num, 1] = loss_rpn[2]

        losses[iter_num, 2] = loss_class[1]
        losses[iter_num, 3] = loss_class[2]
        losses[iter_num, 4] = loss_class[3]

        progbar.update(iter_num + 1,
                       [('rpn_cls', losses[iter_num, 0]),
                        ('rpn_regr', losses[iter_num, 1]),
                        ('detector_cls', losses[iter_num, 2]),
                        ('detector_regr', losses[iter_num, 3])])

        iter_num += 1

        if iter_num == epoch_length:
            # End of epoch: average every column of the loss table.
            loss_rpn_cls = np.mean(losses[:, 0])
            loss_rpn_regr = np.mean(losses[:, 1])
            loss_class_cls = np.mean(losses[:, 2])
            loss_class_regr = np.mean(losses[:, 3])
            class_acc = np.mean(losses[:, 4])

            mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(
                rpn_accuracy_for_epoch)
            rpn_accuracy_for_epoch = []

            if C.verbose:
                print(
                    f'Mean number of bounding boxes from RPN overlapping '
                    f'ground truth boxes: {mean_overlapping_bboxes}')
                print(f'Classifier accuracy for bounding boxes from RPN: {class_acc}')
                print(f'Loss RPN classifier: {loss_rpn_cls}')
                print(f'Loss RPN regression: {loss_rpn_regr}')
                print(f'Loss Detector classifier: {loss_class_cls}')
                print(f'Loss Detector regression: {loss_class_regr}')
                print(f'Elapsed time: {time.time() - start_time}')

            curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
            iter_num = 0
            start_time = time.time()

            if curr_loss < best_loss:
                if C.verbose:
                    print(
                        f'Total loss decreased from {best_loss} to {curr_loss}, saving weights')
                best_loss = curr_loss
            # A checkpoint is written after every epoch, whether or not the
            # loss improved, to '<stem>_<epoch, zero-padded to 4>.hdf5'.
            model_all.save_weights(model_path_regex.group(1) + "_" + '{:04d}'.format(
                epoch_num) + model_path_regex.group(2))

            break

print('Training complete, exiting.')


Starting training
Epoch 1/2
Mean number of bounding boxes from RPN overlapping ground truth boxes: 24.486
Classifier accuracy for bounding boxes from RPN: 0.576875
Loss RPN classifier: 3.285486337224722
Loss RPN regression: 0.2552903192457743
Loss Detector classifier: 0.7756384629756212
Loss Detector regression: 5.498929038527189
Elapsed time: 836.9372932910919
Total loss decreased from inf to 9.815344157973307, saving weights
Epoch 2/2
Average number of overlapping bounding boxes from RPN = 24.486 for 1000 previous iterations
Mean number of bounding boxes from RPN overlapping ground truth boxes: 24.226
Classifier accuracy for bounding boxes from RPN: 0.65884375
Loss RPN classifier: 3.070572531109231
Loss RPN regression: 0.2294521200035233
Loss Detector classifier: 0.6594145148396492
Loss Detector regression: -3.689131682147039
Elapsed time: 810.2452130317688
Total loss decreased from 9.815344157973307 to 0.2703074838053645, saving weights
Training complete, exiting.


In [27]:
# Shape of the last image batch drawn from the training generator by the loop.
X.shape

(1, 416, 554, 3)