In [1]:
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tqdm import tqdm

import os

In [None]:
from utils.utils import *
from utils.intersection_over_union import IoU_np, IoU_tensor
from utils.Augmentation import *
from utils.label_generator import *
from utils.regional_interest_projection import *

In [2]:
# Load the annotation table and one sample frame used for quick sanity checks.
df = pd.read_csv('./res/train_df.csv')

# cv2.imread's second argument is an IMREAD_* flag, not a colour-conversion
# code, so the BGR -> RGB conversion has to be a separate cvtColor call.
sample_image_path = './res/train_imgs/001-1-1-01-Z17_A-0000001.jpg'
image = cv2.imread(sample_image_path)
if image is None:
    # imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f'Could not read sample image: {sample_image_path}')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Add a leading batch axis so the frame can be fed to a model directly.
inputs = tf.expand_dims(image, 0)

In [3]:
# plt.imshow(image)
# plt.axis('off')
# plt.show()

# Faster R-CNN

### Image resize

In [19]:
# Resize ratios for the x and y axes. They are defined here so the cell runs on
# a fresh kernel, and they must stay equal to the ratios used inside
# traingtGenerator (0.4, 0.4), which resizes the images themselves.
Rx, Ry = 0.4, 0.4

# df_resize comes from the project utils; presumably it rescales the annotated
# coordinates by (Rx, Ry) -- confirm against its implementation.
train = df_resize(df, Rx, Ry)

### Ground Truth Generation

In [None]:
# Derive the ground-truth table from the resized annotations.
# NOTE(review): ground_truth_generator is not defined in this notebook
# (presumably it comes from one of the utils star-imports), so its exact output
# format cannot be checked here.
ground_truth = ground_truth_generator(train)
# Keep only the 'x', 'y', 'w', 'h' entries as a NumPy array.
GT = np.array(ground_truth[['x', 'y', 'w', 'h']])

### Anchor boxes

In [30]:
# Anchor scales and (width, height) aspect-ratio multipliers: 5 x 5 combinations.
scales = [140, 160, 180, 210, 240]
ratio = [(1/np.sqrt(3), np.sqrt(3)), (1/np.sqrt(2), np.sqrt(2)), (1, 1), (np.sqrt(2), 1/np.sqrt(2)), (np.sqrt(3), 1/np.sqrt(3))]

# `image1` was never defined in the notebook, so this cell failed on a fresh
# kernel. It is defined here as the sample frame resized to the training
# resolution. NOTE(review): 768x432 (width x height) is inferred from the
# bounds used by the out-of-boundary check and from the generator's 0.4 resize
# of 1080x1920 frames -- confirm this is what `image1` originally was.
image1 = cv2.resize(image, (768, 432))
anchor_boxes = Anchor_Boxes(image1.shape, scales, ratio, model='vgg')

In [31]:
bboxes = anchors_to_coordinates(anchor_boxes)

# Indices of anchors that leave the 768x432 image on any side.
# The previous version chained np.where(...) calls with `or`; np.where returns a
# non-empty tuple, which is always truthy, so only the first condition was ever
# applied. The four comparisons are kept exactly as written and combined
# element-wise instead.
# NOTE(review): the column usage (0 and 2 against 0, 1 against 768, 3 against
# 432) suggests anchors_to_coordinates returns (x1, x2, y1, y2) -- confirm.
out_boundaries_indxes = np.where(
    (bboxes[:, 0] < 0)
    | (bboxes[:, 2] < 0)
    | (bboxes[:, 1] > 768)
    | (bboxes[:, 3] > 432)
)[0]

### Label Generation

In [None]:
def _load_train_image(index, size):
    """Read training image `index`, resize it to `size` and divide pixel values by 255."""
    img = tf.io.read_file(train_val_dir + 'train/' + train['image'].iloc[index])
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, size)
    return img / 255


def _keypoint_mask(target, gt, gt_min):
    """Express each coordinate of `target` relative to the ground-truth box.

    Values at even positions are shifted by gt_min[0] and divided by gt[2];
    values at odd positions are shifted by gt_min[1] and divided by gt[3].
    """
    # list() gives plain positional access whether `target` is a list, an array
    # or a pandas Series (integer indexing on a label-indexed Series is deprecated).
    return [
        (value - gt_min[0]) / gt[2] if position % 2 == 0 else (value - gt_min[1]) / gt[3]
        for position, value in enumerate(list(target))
    ]


def _labels_for(target):
    """Build the (cls_label, reg_label, gt, mask) label tuple for one set of keypoints."""
    gt, gt_min = gt_generator(target)
    mask = _keypoint_mask(target, gt, gt_min)
    cls_label, reg_label = label_generator(gt, anchor_boxes, out_boundaries_indxes)
    return cls_label, reg_label, gt, mask


def traingtGenerator():
    """Yield `(image, (cls_label, reg_label, gt, mask))` training samples.

    The training set is traversed six times, in this order: original images,
    left-right flips, shifts, rotations, brightness changes and added noise.
    Relies on the notebook-level names `train`, `train_val_dir`, `anchor_boxes`
    and `out_boundaries_indxes` and on the augmentation helpers from utils.
    """
    Rx, Ry = 0.4, 0.4
    image_size = (1080, 1920, 3)
    size = [int(image_size[0] * Rx), int(image_size[1] * Ry)]
    iter_num = len(train)
    # Columns 1..48 of `train` are used as the keypoint targets of every row.
    keypoints = train.iloc[:, 1:49]

    # Pass 1: original images.
    for i in range(iter_num):
        img = _load_train_image(i, size)
        target = list(keypoints.iloc[i, :])
        yield img, _labels_for(target)

    # Pass 2: left-right flipped images.
    for i in range(iter_num):
        img = _load_train_image(i, size)
        target = keypoints.iloc[i, :]
        img, target = left_right_flip(img, target)
        yield img, _labels_for(target)

    # Pass 3: shifted images. The mask is built from the *shifted* keypoints;
    # it was previously built from the un-shifted ones while the box came from
    # the shifted ones, so the two were inconsistent.
    for i in range(iter_num):
        img = _load_train_image(i, size)
        target = keypoints.iloc[i, :]
        img_list, target_list = shift_images(img, target)
        for shifted_img, shifted_target in zip(img_list, target_list):
            yield shifted_img, _labels_for(shifted_target)

    # Pass 4: rotated images (same mask fix as for the shifted pass).
    for i in range(iter_num):
        img = _load_train_image(i, size)
        target = keypoints.iloc[i, :]
        img_list, target_list = rotate_augmentation(img, target)
        for rotated_img, rotated_target in zip(img_list, target_list):
            yield rotated_img, _labels_for(rotated_target)

    # Pass 5: brightness changes leave the keypoints untouched, so one label
    # tuple is shared by every brightness variant of an image.
    for i in range(iter_num):
        img = _load_train_image(i, size)
        target = keypoints.iloc[i, :]
        labels = _labels_for(target)
        for altered_brightness_images in alter_brightness(img):
            yield altered_brightness_images, labels

    # Pass 6: noisy images, labels unchanged.
    for i in range(iter_num):
        img = _load_train_image(i, size)
        target = keypoints.iloc[i, :]
        noisy_img = add_noise(img)
        yield noisy_img, _labels_for(target)

In [None]:
batch_size = 16

# `size` only existed as a local inside traingtGenerator, so it is defined here
# as well. It is computed with the same expression the generator uses
# (1080x1920 frames scaled by 0.4) and must stay in sync with it.
size = [int(1080 * 0.4), int(1920 * 0.4)]

# Per-sample signature of what traingtGenerator yields (TensorSpec defaults to
# float32). Shapes are written as explicit tuples: `(n)` is just the integer n.
train_dataset = tf.data.Dataset.from_generator(
    traingtGenerator,
    output_signature=(
        tf.TensorSpec(shape=(size[0], size[1], 3)),      # image
        (
            tf.TensorSpec(shape=(len(anchor_boxes),)),   # cls_label, one per anchor
            tf.TensorSpec(shape=(len(anchor_boxes), 4)), # reg_label, four per anchor
            tf.TensorSpec(shape=(4,)),                   # gt
            tf.TensorSpec(shape=(48,)),                  # mask
        ),
    ),
).batch(batch_size).prefetch(16*4)

## Region Proposal Network

In [44]:
class RPN(tf.keras.models.Model):
    """Region Proposal Network head on top of a shared feature extractor.

    Args:
        base_model: callable feature extractor applied to the input images.
        anchor_boxes: 2-D array of anchors, indexed as [:, 0..3]
            (presumably centre x, centre y, width, height -- confirm).
        k: number of anchors predicted per feature-map location.
        n_sample: divisor applied to the summed classification loss.
        rpn_lambda: weight of the regression loss in the total loss. It defaults
            to 10, the value train_step previously hard-coded.
        **kwargs: forwarded to tf.keras.models.Model.
    """

    def __init__(self, base_model, anchor_boxes, k=9, n_sample=32, rpn_lambda=10, **kwargs):
        super(RPN, self).__init__(**kwargs)
        self.base_model = base_model
        self.anchor_boxes = anchor_boxes
        self.num_of_anchor = len(self.anchor_boxes)
        self.n_sample = n_sample
        self.k = k
        self.rpn_lambda = rpn_lambda

        # 3x3 "sliding window" followed by two sibling 1x1 heads.
        self.window = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same')
        self.bbox_reg = tf.keras.layers.Conv2D(filters=self.k*4, kernel_size=1)
        self.bbox_reg_reshape = tf.keras.layers.Reshape((-1, 4), name='reg_out')
        self.cls = tf.keras.layers.Conv2D(filters=self.k, kernel_size=1, activation='sigmoid')
        self.cls_reshape = tf.keras.layers.Reshape((-1, 1), name='cls_out')

        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.test_loss_tracker = tf.keras.metrics.Mean(name='test_loss')

    def compile(self, optimizer):
        """Attach the optimizer; losses are computed by Cls_Loss / Reg_Loss."""
        super(RPN, self).compile()
        self.optimizer = optimizer

    def Cls_Loss(self, y_true, y_pred):
        """Binary cross-entropy over anchors whose label is not -1, divided by n_sample."""
        indices = tf.where(tf.not_equal(y_true, tf.constant(-1.0, dtype=tf.float32)))
        target = tf.gather_nd(y_true, indices)
        output = tf.gather_nd(y_pred, indices)
        # NOTE(review): `output` keeps a trailing unit axis while `target` does
        # not; Keras may squeeze it and average over the selected anchors before
        # the SUM reduction -- confirm the intended normalisation.
        return tf.losses.BinaryCrossentropy(reduction=tf.losses.Reduction.SUM)(target, output)/self.n_sample

    def Reg_Loss(self, y_true, y_pred):
        """Huber loss over anchors with a non-zero regression target, divided by the anchor count."""
        indices = tf.reduce_any(tf.not_equal(y_true, 0), axis=-1)
        return tf.losses.Huber(reduction=tf.losses.Reduction.SUM)(y_true[indices], y_pred[indices])/self.num_of_anchor

    def _total_loss(self, y_cls, y_reg, cls, bbox_reg):
        """Combine both losses with the same weighting for training and evaluation."""
        cls_loss = self.Cls_Loss(y_cls, cls)
        reg_loss = self.Reg_Loss(y_reg, bbox_reg)
        return cls_loss + self.rpn_lambda * reg_loss

    def train_step(self, data):
        """Run one optimisation step and return the running mean loss as 'rpn_loss'."""
        x, y = data
        y_cls = y[0]
        y_reg = y[1]

        with tf.GradientTape() as tape:
            cls, bbox_reg, _ = self(x, training=True)
            losses = self._total_loss(y_cls, y_reg, cls, bbox_reg)

        trainable_vars = self.trainable_variables
        grad = tape.gradient(losses, trainable_vars)
        self.optimizer.apply_gradients(zip(grad, trainable_vars))
        self.loss_tracker.update_state(losses)
        return {'rpn_loss': self.loss_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch and return the running mean loss as 'rpn_loss_val'."""
        x, y = data
        y_cls = y[0]
        y_reg = y[1]

        cls, bbox_reg, _ = self(x, training=False)
        # Previously this read an undefined local `rpn_lambda`, which either
        # raised NameError or picked up an unrelated notebook global.
        losses = self._total_loss(y_cls, y_reg, cls, bbox_reg)

        self.test_loss_tracker.update_state(losses)
        return {'rpn_loss_val': self.test_loss_tracker.result()}

    def bbox_regression(self, boxes):
        """Encode rank-3 `boxes` (last axis of 4) as offsets relative to the anchors."""
        tx = (boxes[:, :, 0] - self.anchor_boxes[:, 0]) / self.anchor_boxes[:, 2]
        ty = (boxes[:, :, 1] - self.anchor_boxes[:, 1]) / self.anchor_boxes[:, 3]
        # Clamp sizes to a tiny positive value so the logarithm stays finite.
        tw = tf.math.log(tf.maximum(boxes[:, :, 2], np.finfo(np.float64).eps) / self.anchor_boxes[:, 2])
        th = tf.math.log(tf.maximum(boxes[:, :, 3], np.finfo(np.float64).eps) / self.anchor_boxes[:, 3])
        return tf.stack([tx, ty, tw, th], -1)

    def inverse_bbox_regression(self, boxes):
        """Decode anchor-relative offsets back into boxes (inverse of bbox_regression)."""
        gx = self.anchor_boxes[:, 2] * boxes[:, :, 0] + self.anchor_boxes[:, 0]
        gy = self.anchor_boxes[:, 3] * boxes[:, :, 1] + self.anchor_boxes[:, 1]
        gw = self.anchor_boxes[:, 2] * tf.exp(boxes[:, :, 2])
        gh = self.anchor_boxes[:, 3] * tf.exp(boxes[:, :, 3])
        return tf.stack([gx, gy, gw, gh], axis=-1)

    def call(self, inputs):
        """Return (objectness scores, encoded box offsets, backbone feature map)."""
        feature_extractor = self.base_model(inputs)
        intermediate = self.window(feature_extractor)
        cls = self.cls(intermediate)
        cls = self.cls_reshape(cls)
        bbox_reg = self.bbox_reg(intermediate)
        bbox_reg = self.bbox_reg_reshape(bbox_reg)
        bbox_reg = self.bbox_regression(bbox_reg)
        return cls, bbox_reg, feature_extractor

##  Classifier Network

In [None]:
class get_candidate_layer(tf.keras.layers.Layer):
    """Clip proposals to the image, zero the score of small ones and keep the top-scoring."""

    def __init__(self, **kwargs):
        super(get_candidate_layer, self).__init__(**kwargs)

    def anchors_clip(self, boxes, size=(432, 768)):
        """Clip centre-format (x, y, w, h) boxes to an image of `size` = (height, width)."""
        max_y, max_x = size[0], size[1]
        half_w = boxes[:, :, 2]/2
        half_h = boxes[:, :, 3]/2

        # Convert to corners, clip each corner to the image, then convert back.
        left = tf.clip_by_value(boxes[:, :, 0] - half_w, 0, max_x)
        right = tf.clip_by_value(boxes[:, :, 0] + half_w, 0, max_x)
        top = tf.clip_by_value(boxes[:, :, 1] - half_h, 0, max_y)
        bottom = tf.clip_by_value(boxes[:, :, 1] + half_h, 0, max_y)

        w = right - left
        h = bottom - top
        return tf.stack([left + w/2, top + h/2, w, h], axis=-1)

    def call(self, x):
        """Take (scores, proposals, keep_count) and return the kept (rois, scores)."""
        scores, rps, n_train_pre_nms = x
        clipped = self.anchors_clip(rps)

        # Proposals that are not wider and taller than 16 get their score zeroed.
        wide_enough = tf.cast(tf.math.greater(clipped[:, :, 2], 16), tf.float32)
        tall_enough = tf.cast(tf.math.greater(clipped[:, :, 3], 16), tf.float32)
        scores = tf.math.multiply(scores, tf.expand_dims(wide_enough, -1))
        scores = tf.math.multiply(scores, tf.expand_dims(tall_enough, -1))

        # Rank along the proposal axis and keep the first `n_train_pre_nms` entries.
        ranking = tf.argsort(scores, direction='DESCENDING', axis=1)
        top_ranked = ranking[:, :n_train_pre_nms]
        kept_rois = tf.gather_nd(clipped, top_ranked, batch_dims=1)
        kept_scores = tf.gather_nd(scores, top_ranked, batch_dims=1)
        return kept_rois, kept_scores

In [None]:
class NMS(tf.keras.layers.Layer):
    """Batched non-maximum suppression over centre-format (x, y, w, h) proposals.

    Args:
        iou_threshold: overlap above which a lower-scoring box is suppressed.
        **kwargs: forwarded to tf.keras.layers.Layer.
    """

    def __init__(self, iou_threshold=0.7, **kwargs):
        super(NMS, self).__init__(**kwargs)
        self.iou_threshold = iou_threshold

    def call(self, inputs):
        """Take (rois, scores, max_output_size) and return the selected rois.

        The result has `max_output_size` boxes per image because the selected
        indices are padded to that length.
        """
        rois, scores, max_output_size = inputs

        # non_max_suppression_padded works on corner coordinates, while the
        # proposals handed to this layer are (centre x, centre y, width, height).
        # Convert for the overlap test only; the returned boxes keep the
        # original format.
        half_w = rois[:, :, 2] / 2
        half_h = rois[:, :, 3] / 2
        corners = tf.stack(
            [
                rois[:, :, 1] - half_h,
                rois[:, :, 0] - half_w,
                rois[:, :, 1] + half_h,
                rois[:, :, 0] + half_w,
            ],
            axis=-1,
        )

        # Drop only the trailing score axis; a bare tf.squeeze would also drop
        # the batch axis when the batch holds a single image.
        if scores.shape.rank == 3:
            scores = tf.squeeze(scores, axis=-1)

        selected_indices_padded = tf.image.non_max_suppression_padded(
            corners,
            scores,
            max_output_size=max_output_size,
            # Use the configured threshold; it was previously hard-coded to 0.7.
            iou_threshold=self.iou_threshold,
            pad_to_max_output_size=True
        )[0]
        return tf.gather(rois, selected_indices_padded, batch_dims=1)

In [None]:
class RoIpool(tf.keras.layers.Layer):
    """Crop each RoI from the feature map and resize it to pool_size x pool_size.

    Args:
        pool_size: side length of the pooled output.
        num_rois: nominal number of RoIs per image. Kept for interface
            compatibility; the actual count is read from the input.
        batch_size: nominal batch size. Kept for interface compatibility; the
            actual size is read from the input.
        **kwargs: forwarded to tf.keras.layers.Layer.
    """

    def __init__(self, pool_size=7, num_rois=128, batch_size=16, **kwargs):
        super(RoIpool, self).__init__(**kwargs)
        self.pool_size = pool_size
        self.num_rois = num_rois
        self.batch_size = batch_size

    def cal_rois_ratio(self, nmses, size=[432, 768]):
        """Convert centre-format boxes into [y1, x1, y2, x2] divided by `size` = [height, width]."""
        x1 = (nmses[:, :, 0] - nmses[:, :, 2]/2)/size[1]
        x2 = (nmses[:, :, 0] + nmses[:, :, 2]/2)/size[1]
        y1 = (nmses[:, :, 1] - nmses[:, :, 3]/2)/size[0]
        y2 = (nmses[:, :, 1] + nmses[:, :, 3]/2)/size[0]
        return tf.stack([y1, x1, y2, x2], axis=-1)

    def call(self, inputs):
        """Take (feature_map, rois) and return crops shaped (batch, rois, pool, pool, channels)."""
        feature_map, nmses = inputs
        n_channel = feature_map.shape[-1]

        # Read the batch size and RoI count from the input instead of the
        # constructor values, so a short final batch or a different number of
        # proposals does not break the box-index bookkeeping. Static sizes are
        # preferred when known so downstream layers keep full shape information.
        dynamic_shape = tf.shape(nmses)
        n_batch = nmses.shape[0] if nmses.shape[0] is not None else dynamic_shape[0]
        n_rois = nmses.shape[1] if nmses.shape[1] is not None else dynamic_shape[1]

        boxes = tf.reshape(self.cal_rois_ratio(nmses), (-1, 4))
        # Image index of every box: 0 repeated n_rois times, then 1, and so on.
        box_indices = tf.repeat(tf.range(n_batch), n_rois)
        rois = tf.image.crop_and_resize(
            feature_map,
            boxes,
            box_indices=box_indices,
            crop_size=[self.pool_size, self.pool_size]
        )
        return tf.reshape(rois, shape=(n_batch, n_rois, self.pool_size, self.pool_size, n_channel))

In [None]:
class Classifier(tf.keras.models.Model):
    """Per-RoI head predicting an object score, a box refinement and 48 keypoint values.

    Args:
        classifier_lambda: stored on the instance; not used by the losses below.
        **kwargs: forwarded to tf.keras.models.Model.
    """

    def __init__(self, classifier_lambda=1, **kwargs):
        super(Classifier, self).__init__(**kwargs)
        self.classifier_lambda = classifier_lambda

        # Score / box branch.
        self.conv = tf.keras.layers.Conv2D(2048, 7, 7, name='cls_conv')
        self.flatten = tf.keras.layers.Flatten(name='cls_flatten')
        self.dense = tf.keras.layers.Dense(2048, name='cls_dense')
        self.cls = tf.keras.layers.Dense(1, activation='sigmoid', name='cls_out')
        self.bbox_reg = tf.keras.layers.Dense(4, name='bbox_out')

        # Keypoint ("mask") branch: upsample, 1x1 conv, then a dense sigmoid output.
        self.deconv = tf.keras.layers.Conv2DTranspose(filters=256, kernel_size=(2, 2), strides=2)
        self.bn = tf.keras.layers.BatchNormalization()
        self.relu = tf.keras.layers.ReLU()

        self.conv1 = tf.keras.layers.Conv2D(filters=80, kernel_size=1)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.relu1 = tf.keras.layers.ReLU()

        self.conv2 = tf.keras.layers.Conv2D(80, 14, 14)
        self.flatten2 = tf.keras.layers.Flatten()
        self.mask_out = tf.keras.layers.Dense(48, activation='sigmoid')

    def compile(self, optimizer):
        """Attach the optimizer and create the train / test loss trackers."""
        super(Classifier, self).compile()
        self.optimizer = optimizer
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')
        self.test_loss_tracker = tf.keras.metrics.Mean(name='test_loss')

    def Cls_Loss(self, y_true, y_pred):
        """Binary cross-entropy between RoI labels and predicted scores."""
        return tf.losses.BinaryCrossentropy()(y_true, y_pred)

    def Reg_Loss(self, y_true, y_pred, indices):
        """Huber loss restricted to the RoIs selected by the boolean `indices`."""
        return tf.losses.Huber()(y_true[indices], y_pred[indices])

    def Mask_Loss(self, y_true, y_pred, indices):
        """MSE between predicted and target keypoints for the selected RoIs.

        The per-image target is repeated once per RoI before selecting.
        NOTE(review): tf.keras.losses.MSE yields one value per selected RoI, not
        a scalar, so the total loss is a vector whenever RoIs are selected --
        confirm this is the intended weighting against the other two terms.
        """
        y_true = tf.reshape(tf.tile(y_true, [1, indices.shape[1]]), (-1, indices.shape[1], 48))
        return tf.keras.losses.MSE(y_true[indices], y_pred[indices])

    def train_step(self, data):
        """Run one optimisation step and return the running mean loss as 'classifier_loss'."""
        x, y = data
        y_cls = y[0]
        y_reg = y[1]
        y_mask = y[2]
        indices = tf.not_equal(y_cls, 0)

        with tf.GradientTape() as tape:
            # training=True so the BatchNormalization layers use batch
            # statistics and update their moving averages; it was previously
            # omitted, leaving them in inference mode during training.
            cls, bbox_reg, mask, _ = self(x, training=True)
            cls_loss = self.Cls_Loss(y_cls, cls)
            reg_loss = self.Reg_Loss(y_reg, bbox_reg, indices)
            mask_loss = self.Mask_Loss(y_mask, mask, indices)
            losses = cls_loss + reg_loss + mask_loss

        trainable_vars = self.trainable_variables
        grad = tape.gradient(losses, trainable_vars)
        self.optimizer.apply_gradients(zip(grad, trainable_vars))
        self.loss_tracker.update_state(losses)
        return {'classifier_loss': self.loss_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch and return the running mean loss as 'classifier_loss_val'."""
        x, y = data
        y_cls = y[0]
        y_reg = y[1]
        y_mask = y[2]
        indices = tf.not_equal(y_cls, 0)

        cls, bbox_reg, mask, _ = self(x, training=False)
        cls_loss = self.Cls_Loss(y_cls, cls)
        reg_loss = self.Reg_Loss(y_reg, bbox_reg, indices)
        mask_loss = self.Mask_Loss(y_mask, mask, indices)
        losses = cls_loss + reg_loss + mask_loss

        self.test_loss_tracker.update_state(losses)
        return {'classifier_loss_val': self.test_loss_tracker.result()}

    def bbox_regression(self, bbox, nmses):
        """Encode `bbox` as offsets relative to the proposals `nmses` (both rank 3, last axis 4)."""
        tx = (bbox[:, :, 0] - nmses[:, :, 0]) / nmses[:, :, 2]
        ty = (bbox[:, :, 1] - nmses[:, :, 1]) / nmses[:, :, 3]
        # Clamp sizes to a tiny positive value so the logarithm stays finite.
        tw = tf.math.log(tf.maximum(bbox[:, :, 2], np.finfo(np.float64).eps) / nmses[:, :, 2])
        th = tf.math.log(tf.maximum(bbox[:, :, 3], np.finfo(np.float64).eps) / nmses[:, :, 3])
        return tf.stack([tx, ty, tw, th], -1)

    @staticmethod
    def inverse_bbox_regression(bbox, nmses):
        """Decode proposal-relative offsets back into boxes (inverse of bbox_regression)."""
        gx = nmses[:, :, 2] * bbox[:, :, 0] + nmses[:, :, 0]
        gy = nmses[:, :, 3] * bbox[:, :, 1] + nmses[:, :, 1]
        gw = nmses[:, :, 2] * tf.exp(bbox[:, :, 2])
        gh = nmses[:, :, 3] * tf.exp(bbox[:, :, 3])
        return tf.stack([gx, gy, gw, gh], -1)

    def call(self, inputs):
        """Take (pooled rois, proposals) and return (scores, encoded boxes, keypoints, proposals).

        Every layer is applied independently to each RoI through TimeDistributed.
        """
        rois, nms = inputs

        # Keypoint branch.
        x = tf.keras.layers.TimeDistributed(self.deconv)(rois)
        x = tf.keras.layers.TimeDistributed(self.bn)(x)
        x = tf.keras.layers.TimeDistributed(self.relu)(x)

        x = tf.keras.layers.TimeDistributed(self.conv1)(x)
        x = tf.keras.layers.TimeDistributed(self.bn1)(x)
        x = tf.keras.layers.TimeDistributed(self.relu1)(x)

        x = tf.keras.layers.TimeDistributed(self.conv2)(x)
        x = tf.keras.layers.TimeDistributed(self.flatten2)(x)
        mask = tf.keras.layers.TimeDistributed(self.mask_out)(x)

        # Score / box branch.
        x = tf.keras.layers.TimeDistributed(self.conv)(rois)
        x = tf.keras.layers.TimeDistributed(self.flatten)(x)
        feature_vector = tf.keras.layers.TimeDistributed(self.dense)(x)
        clss = tf.keras.layers.TimeDistributed(self.cls)(feature_vector)
        bbox = tf.keras.layers.TimeDistributed(self.bbox_reg)(feature_vector)
        bbox_reg = self.bbox_regression(bbox, nms)

        return clss, bbox_reg, mask, nms

## Faster R-CNN

In [None]:
class Faster_RCNN(tf.keras.models.Model):
    """End-to-end detector: RPN -> candidate selection -> NMS -> RoI pooling -> classifier.

    Args:
        img_size: input image shape, forwarded to the RPN.
        anchor_boxes: anchors shared with the RPN.
        k: number of anchors per feature-map location.
        n_sample: divisor used by the RPN classification loss.
        backbone: backbone specification, forwarded to the RPN.
        rpn_lambda: weight of the RPN regression loss.
        pool_size: side length of the RoI-pooled features.
        num_rois: number of RoIs per image handed to RoIpool.
        batch_size: batch size handed to RoIpool.
        classifier_lambda: forwarded to the Classifier.
        **kwargs: forwarded to tf.keras.models.Model.
    """

    def __init__(self, img_size, anchor_boxes, k, n_sample, backbone, rpn_lambda, pool_size, num_rois, batch_size, classifier_lambda, **kwargs):
        # Keyword arguments must be unpacked with `**`; `*kwargs` passed the
        # dictionary's keys as positional arguments.
        super(Faster_RCNN, self).__init__(**kwargs)
        self.img_size = img_size
        self.anchor_boxes = anchor_boxes
        self.k = k
        self.n_sample = n_sample
        self.backbone = backbone
        self.rpn_lambda = rpn_lambda
        self.pool_size = pool_size
        self.num_rois = num_rois
        self.batch_size = batch_size
        self.classifier_lambda = classifier_lambda
        # Proposal counts kept before / after NMS, for training and inference.
        self.n_train_pre_nms = 12000
        self.n_train_post_nms = 2000
        self.n_test_pre_nms = 6000
        self.n_test_post_nms = 128
        self.iou_threshold = 0.7

        # NOTE(review): `img_size` and `backbone` are not parameters of the
        # RPN.__init__ defined in this notebook, which requires a `base_model`
        # instead -- confirm which RPN definition this call targets.
        self.rpn = RPN(img_size= self.img_size, anchor_boxes=self.anchor_boxes, k=self.k, n_sample=self.n_sample, backbone=self.backbone, rpn_lambda=self.rpn_lambda, name='rpn')
        self.get_candidate = get_candidate_layer(name='get_candidate')
        self.get_nms = NMS(iou_threshold=self.iou_threshold, name='get_nms')
        self.roipool = RoIpool(pool_size=self.pool_size, num_rois=self.num_rois, batch_size=self.batch_size, name='roipool')
        self.classifier = Classifier(classifier_lambda=self.classifier_lambda, name='classifier')
        self.train_stage = None

    def compile(self, rpn_optimizer, classifier_optimizer):
        """Compile the two sub-models with their own optimizers."""
        super(Faster_RCNN, self).compile()
        self.rpn.compile(optimizer=rpn_optimizer)
        self.classifier.compile(optimizer=classifier_optimizer)

    def call(self, inputs):
        """Run inference and return (scores, decoded boxes, keypoints) per kept RoI."""
        scores, rps, feature_map = self.rpn(inputs)
        rps = self.rpn.inverse_bbox_regression(rps)
        candidate_area, scores = self.get_candidate((scores, rps, self.n_test_pre_nms))
        nms = self.get_nms((candidate_area, scores, self.n_test_post_nms))
        rois = self.roipool((feature_map, nms))
        cls, bbox_reg, mask, nms = self.classifier((rois, nms))
        predict = self.classifier.inverse_bbox_regression(bbox_reg, nms)
        return cls, predict, mask

In [None]:
tf.keras.backend.clear_session()

# Model hyper-parameters. `anchor_boxes` and `batch_size` are reused from the
# earlier cells as they are.
# NOTE(review): `image1` must already be defined by an earlier cell (it is the
# image whose shape was used to build the anchors) -- confirm on a fresh kernel.
img_size=image1.shape
k=5*5  # presumably 5 scales x 5 aspect ratios -- confirm against the anchor cell
n_sample=32
backbone='resnet50'
rpn_lambda=10**3
pool_size=7
num_rois=128
classifier_lambda=10

frcnn = Faster_RCNN(
    img_size=img_size,
    anchor_boxes=anchor_boxes,
    k=k,
    n_sample=n_sample,
    backbone=backbone,
    rpn_lambda=rpn_lambda,
    pool_size=pool_size,
    num_rois=num_rois,
    batch_size=batch_size,
    classifier_lambda=classifier_lambda
)
# Restore the pre-trained RPN weights only.
# To restore the full model instead: frcnn.load_weights("./frcnn")
frcnn.rpn.load_weights('./rpn')

# `learning_rate` is the supported argument name; `lr` is a deprecated alias
# that newer Keras releases reject.
frcnn.compile(
    rpn_optimizer = tf.keras.optimizers.Adam(learning_rate=0.01),
    classifier_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
)

In [None]:
def _classifier_batch(model, images, gts, n_pre_nms, n_post_nms):
    """Run the RPN -> candidates -> NMS -> label -> RoI-pool pipeline for one batch.

    Returns ((rois, nms), cls_labels, box_labels), ready for the classifier's
    train_step / test_step.
    """
    scores, rps, feature_map = model.rpn(images)
    rps = model.rpn.inverse_bbox_regression(rps)
    candidate_area, scores = model.get_candidate((scores, rps, n_pre_nms))
    nms = model.get_nms((candidate_area, scores, n_post_nms))
    box_labels, cls_labels, nms = classifier_label_generator(nms, gts)
    rois = model.roipool((feature_map, nms))
    return (rois, nms), cls_labels, box_labels


def frcnn_train_step(model, train_dataset, train_stage, epochs=1, valid_dataset=None, change_lr=False, rpn_lr=None, cls_lr=None):
    """Train one stage of the alternating Faster R-CNN schedule.

    Args:
        model: Faster_RCNN instance exposing `rpn`, `classifier`, `get_candidate`,
            `get_nms` and `roipool`.
        train_dataset: iterable of (images, (rpn_cls, rpn_reg, gts, mask)) batches.
        train_stage: 1 and 3 train the RPN; any other value trains the classifier.
            Stages 1-4 also set which sub-models are trainable.
        epochs: number of passes over `train_dataset`.
        valid_dataset: optional iterable with the same structure, evaluated after
            every epoch without updating any weights.
        change_lr: when True, apply `rpn_lr` / `cls_lr` (if given) before training.
        rpn_lr: new learning rate for the RPN optimizer.
        cls_lr: new learning rate for the classifier optimizer.

    Returns:
        The same `model`, after training.
    """
    if change_lr:
        if rpn_lr:
            tf.keras.backend.set_value(model.rpn.optimizer.learning_rate, rpn_lr)
        if cls_lr:
            tf.keras.backend.set_value(model.classifier.optimizer.learning_rate, cls_lr)

    if train_stage == 1:
        print('Train RPNs \n')
        model.rpn.trainable = True
        model.classifier.trainable = False
    elif train_stage == 2:
        print('Train Fast R-CNN using the proposals from RPNs \n')
        model.rpn.trainable = False
        model.rpn.base_model.trainable = True
        model.classifier.trainable = True
    elif train_stage == 3:
        print('Fix the shared convolutional layers and fine-tune unique layers to RPN \n')
        model.rpn.trainable = True
        model.rpn.base_model.trainable = False
        model.classifier.trainable = False
    elif train_stage == 4:
        print('Fine-tune unique layers to Fast R-CNN \n')
        model.rpn.trainable = False
        model.classifier.trainable = True

    trains_rpn = train_stage == 1 or train_stage == 3

    for epoch in range(epochs):
        print(f"epoch {epoch+1}/{epochs}")
        display_loss = display("Training loss (for one batch) at step 0 : 0", display_id=True)
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            y_cls_rpn = y_batch_train[0]
            y_reg_rpn = y_batch_train[1]
            gts = y_batch_train[2]
            mask = y_batch_train[3]

            if trains_rpn:
                result = model.rpn.train_step((x_batch_train, (y_cls_rpn, y_reg_rpn)))
                losses = result['rpn_loss'].numpy()
            else:
                classifier_inputs, cls_labels, box_labels = _classifier_batch(
                    model, x_batch_train, gts, model.n_train_pre_nms, model.n_train_post_nms
                )
                result = model.classifier.train_step((classifier_inputs, (cls_labels, box_labels, mask)))
                losses = result['classifier_loss'].numpy()

            # The reported value is the loss tracker's running mean.
            display_loss.update(f"Training loss at step {step} : {losses}")

        if valid_dataset is not None:
            display_loss_valid = display("validation loss : 0", display_id=True)
            val_losses = None
            for x_batch_test, y_batch_test in valid_dataset:
                y_cls_rpn = y_batch_test[0]
                y_reg_rpn = y_batch_test[1]
                gts = y_batch_test[2]
                mask = y_batch_test[3]

                # Validation must use test_step: train_step would apply
                # gradients and fit the model to the validation data.
                if trains_rpn:
                    result = model.rpn.test_step((x_batch_test, (y_cls_rpn, y_reg_rpn)))
                    val_losses = result['rpn_loss_val'].numpy()
                else:
                    classifier_inputs, cls_labels, box_labels = _classifier_batch(
                        model, x_batch_test, gts, model.n_test_pre_nms, model.n_test_post_nms
                    )
                    result = model.classifier.test_step((classifier_inputs, (cls_labels, box_labels, mask)))
                    val_losses = result['classifier_loss_val'].numpy()

            # Only report when at least one validation batch was evaluated.
            if val_losses is not None:
                display_loss_valid.update(f"validation loss : {val_losses}")
    return model

In [None]:
# Run the four training stages in order (RPN, classifier, RPN fine-tune,
# classifier fine-tune), ten epochs each.
for i in range(1, 5):
    frcnn = frcnn_train_step(model=frcnn, train_dataset=train_dataset, train_stage=i, epochs=10)

In [None]:
# frcnn.save_weights("./frcnn")
# frcnn.classifier.save_weights("./frcnn_classifier")
# frcnn.rpn.save_weights("./frcnn_rpn")