In [1]:
import tensorflow as tf

In [66]:
class Residual_Unit(tf.keras.layers.Layer):
    """Bottleneck residual unit: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The main branch is ``conv1`` (1x1, ``channels // 4`` filters, stride
    ``strides``), ``conv2`` (3x3, ``channels // 4`` filters, 'same' padding)
    and ``conv3`` (1x1, ``channels`` filters).  Every convolution is followed
    by batch normalisation; ReLU is applied after the first two and after the
    residual addition.

    Args:
        channels: Number of output channels of the unit.
        strides: Stride of the first 1x1 convolution.  When it differs from 1
            a projection shortcut (strided 1x1 convolution + batch norm) is
            used; otherwise the input itself is the shortcut, so the input
            must already have ``channels`` channels for the addition to be
            shape-compatible.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, channels, strides=1, **kwargs):
        super(Residual_Unit, self).__init__(**kwargs)
        self.channels = channels
        self.strides = strides

        # 1x1 reduction; carries the stride of the whole unit.
        self.conv1 = tf.keras.layers.Conv2D(filters=self.channels//4, kernel_size=1, strides=self.strides, padding='valid', use_bias=False)
        self.BN1 = tf.keras.layers.BatchNormalization()
        self.relu1 = tf.keras.layers.ReLU()

        # 3x3 convolution at the reduced width; 'same' padding keeps the size.
        self.conv2 = tf.keras.layers.Conv2D(filters=self.channels//4, kernel_size=3, strides=1, padding='same', use_bias=False)
        self.BN2 = tf.keras.layers.BatchNormalization()
        self.relu2 = tf.keras.layers.ReLU()

        # 1x1 expansion back to `channels`.
        self.conv3 = tf.keras.layers.Conv2D(filters=self.channels, kernel_size=1, strides=1, padding='valid', use_bias=False)
        self.BN3 = tf.keras.layers.BatchNormalization()
        self.relu3 = tf.keras.layers.ReLU()

        # Projection shortcut, only created for strided units.
        if self.strides != 1:
            self.conv1a = tf.keras.layers.Conv2D(
                filters=self.channels,
                kernel_size=1,
                strides=self.strides,
                use_bias=False
            )
            self.BN1a = tf.keras.layers.BatchNormalization()

    def call(self, inputs, training=None):
        """Apply the unit to `inputs` and return the activated sum.

        Args:
            inputs: Input feature map.
            training: Forwarded to the batch-normalisation layers.  ``None``
                lets Keras resolve the mode from the surrounding call.
        """
        if self.strides != 1:
            shortcut = self.conv1a(inputs)
            shortcut = self.BN1a(shortcut, training=training)
        else:
            shortcut = inputs

        x = self.conv1(inputs)
        x = self.BN1(x, training=training)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.BN2(x, training=training)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.BN3(x, training=training)

        # The final ReLU comes after the residual addition.
        x = tf.add(x, shortcut)
        return self.relu3(x)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(Residual_Unit, self).get_config()
        config.update({'channels': self.channels, 'strides': self.strides})
        return config

In [78]:
class predictor_head(tf.keras.layers.Layer):
    """Prediction head producing a score and a 4-vector for every output slot.

    A 3x3 convolution (256 filters, 'same' padding, no activation) feeds two
    parallel 1x1 convolutions:

    * ``cls``: ``k`` sigmoid outputs per location, reshaped to ``(-1, 1)``.
    * ``bbox_reg``: ``4 * k`` linear outputs per location, reshaped to
      ``(-1, 4)``.

    Args:
        k: Number of predictions per spatial location (presumably the number
            of anchors per location -- confirm against the target generator).
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, k=3, **kwargs):
        super(predictor_head, self).__init__(**kwargs)
        self.k = k
        self.window = tf.keras.layers.Conv2D(filters=256, kernel_size=3, strides=1, padding='same')
        self.bbox_reg = tf.keras.layers.Conv2D(filters=self.k*4, kernel_size=1, activation='linear')
        self.bbox_reg_reshape = tf.keras.layers.Reshape((-1, 4), name='reg_out')
        self.cls = tf.keras.layers.Conv2D(filters=self.k, kernel_size=1, activation='sigmoid')
        self.cls_reshape = tf.keras.layers.Reshape((-1, 1), name='cls_out')

    def call(self, inputs):
        """Return ``(scores, boxes)`` for the feature map `inputs`.

        ``scores`` has shape ``(batch, N, 1)`` and ``boxes`` has shape
        ``(batch, N, 4)``, where ``N`` is the flattened number of predictions.
        """
        intermediate = self.window(inputs)

        # Classification branch.
        scores = self.cls_reshape(self.cls(intermediate))

        # Box-regression branch.
        boxes = self.bbox_reg_reshape(self.bbox_reg(intermediate))

        return scores, boxes

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(predictor_head, self).get_config()
        config.update({'k': self.k})
        return config

In [83]:
class Feature_Pyramid_Network(tf.keras.models.Model):
    """Bottleneck residual backbone with a top-down feature pyramid.

    The backbone is a 7x7 stride-2 stem followed by four stages of
    ``Residual_Unit`` blocks (3, 4, 23 and 3 units, i.e. ResNet-101-style
    stage depths) with 256, 512, 1024 and 2048 channels.  Each stage starts
    with a stride-2 unit.  Every stage output is projected to 256 channels by
    a 1x1 lateral convolution, merged top-down with bilinear 2x upsampling,
    smoothed by a 3x3 convolution and passed to its own ``predictor_head``.

    The top-down additions require each upsampled map to match the next
    lateral map exactly, so the input height and width should be divisible
    by 32 (five stride-2 reductions).

    ``call`` returns four ``(cls, bbox_reg, feature_map)`` tuples ordered
    from the finest level (P2) to the coarsest (P5).
    """

    # Number of stride-1 units that follow the stride-2 `block1` of each stage.
    # Kept at class level so `__init__` and `call` cannot drift apart.
    _EXTRA_BLOCKS = {2: 2, 3: 3, 4: 22, 5: 2}

    def __init__(self, **kwargs):
        super(Feature_Pyramid_Network, self).__init__(**kwargs)
        # NOTE: the attribute spellings `featuer*` / `upsmapling*` are kept
        # as they are so existing references to these attributes keep working.
        self.conv1_conv = tf.keras.layers.Conv2D(64, kernel_size=7, strides=2, padding='same')
        self.conv1_bn = tf.keras.layers.BatchNormalization()
        self.conv1_relu = tf.keras.layers.Activation('relu')

        # NOTE(review): pred2/pred3/pred4 use Conv2D's default 'valid' padding,
        # so P2-P4 are 2 pixels smaller than their merged maps in each spatial
        # dimension.  `padding='same'` may have been intended; it is left
        # unchanged because it would alter the output shapes.
        self._add_stage(2, 256)
        self.featuer2 = tf.keras.layers.Conv2D(256, 1, 1)
        self.pred2 = tf.keras.layers.Conv2D(256, 3, 1, name='P2')
        self.predictor2 = predictor_head()

        self._add_stage(3, 512)
        self.featuer3 = tf.keras.layers.Conv2D(256, 1, 1)
        self.upsmapling3 = tf.keras.layers.UpSampling2D(size=(2, 2), data_format='channels_last', interpolation='bilinear')
        self.pred3 = tf.keras.layers.Conv2D(256, 3, 1, name='P3')
        self.predictor3 = predictor_head()

        self._add_stage(4, 1024)
        self.featuer4 = tf.keras.layers.Conv2D(256, 1, 1)
        self.upsmapling4 = tf.keras.layers.UpSampling2D(size=(2, 2), data_format='channels_last', interpolation='bilinear')
        self.pred4 = tf.keras.layers.Conv2D(256, 3, 1, name='P4')
        self.predictor4 = predictor_head()

        self._add_stage(5, 2048)
        self.featuer5 = tf.keras.layers.Conv2D(256, 1, 1, name='P5')
        self.predictor5 = predictor_head()
        self.upsmapling5 = tf.keras.layers.UpSampling2D(size=(2, 2), data_format='channels_last', interpolation='bilinear')

    def _add_stage(self, stage, channels):
        """Create the units ``conv{stage}_block1..N`` as attributes of the model."""
        setattr(self, f'conv{stage}_block1', Residual_Unit(channels, strides=2))
        for i in range(self._EXTRA_BLOCKS[stage]):
            setattr(self, f'conv{stage}_block{i+2}', Residual_Unit(channels, strides=1))

    def _run_stage(self, stage, x):
        """Apply every residual unit of backbone stage `stage` to `x` in order."""
        x = getattr(self, f'conv{stage}_block1')(x)
        for i in range(self._EXTRA_BLOCKS[stage]):
            x = getattr(self, f'conv{stage}_block{i+2}')(x)
        return x

    def compile(self, optimizer, **kwargs):
        """Configure the model for training with the custom ``train_step``.

        Args:
            optimizer: Optimizer instance or identifier used by ``train_step``.
            **kwargs: Forwarded to ``tf.keras.Model.compile``.
        """
        # The original referenced an undefined class `RPN` here, which raised
        # NameError as soon as `compile` was called.
        super(Feature_Pyramid_Network, self).compile(**kwargs)
        # `get` returns optimizer instances unchanged and resolves string
        # identifiers such as 'adam' to a real optimizer.
        self.optimizer = tf.keras.optimizers.get(optimizer)
        self.loss_tracker = tf.keras.metrics.Mean(name='loss')

    def Cls_Loss(self, y_true, y_pred):
        """Binary cross-entropy over all entries whose target is not -1.

        Entries of `y_true` equal to -1 are treated as "ignore" and are
        excluded from both the targets and the predictions.
        """
        y_true = tf.convert_to_tensor(y_true)
        # Build the ignore label in the dtype of the targets instead of
        # hard-coding float64, so float32 targets work as well.
        ignore_label = tf.cast(-1.0, y_true.dtype)
        indices = tf.where(tf.not_equal(y_true, ignore_label))
        target = tf.gather_nd(y_true, indices)
        output = tf.gather_nd(y_pred, indices)
        return tf.losses.BinaryCrossentropy()(target, output)

    def Reg_Loss(self, y_true, y_pred):
        """Mean Huber loss over the rows of `y_true` that are not all zero.

        Returns 0 (instead of NaN) when no row has a non-zero target.
        """
        # A row takes part in the loss if any of its target values is non-zero.
        positive = tf.reduce_any(tf.not_equal(y_true, 0), axis=-1)
        loss_fn = tf.losses.Huber(reduction=tf.losses.Reduction.NONE)
        per_row = loss_fn(y_true[positive], y_pred[positive])
        # `per_row` is one-dimensional, so sum / count is its mean;
        # divide_no_nan turns the empty case (0 / 0) into 0.
        row_count = tf.cast(tf.size(per_row), per_row.dtype)
        return tf.math.divide_no_nan(tf.reduce_sum(per_row), row_count)

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: Tuple ``(x, y)`` where ``y[0]`` holds the classification
                targets and ``y[1]`` the box-regression targets.

        Returns:
            Dict with the running mean of the total loss under 'fpn_loss'.
        """
        x, y = data
        y_cls = y[0]
        y_reg = y[1]
        rpn_lambda = 10  # weight of the regression term in the total loss

        with tf.GradientTape() as tape:
            pred = self(x, training=True)
            losses = 0
            # NOTE(review): the same targets are compared with every pyramid
            # level although the levels emit different numbers of predictions;
            # confirm how the targets are laid out before training.
            for cls_pred, bbox_reg, _ in pred:
                cls_loss = self.Cls_Loss(y_cls, cls_pred)
                reg_loss = self.Reg_Loss(y_reg, bbox_reg)
                losses += cls_loss + rpn_lambda * reg_loss

        trainable_vars = self.trainable_variables
        grad = tape.gradient(losses, trainable_vars)
        self.optimizer.apply_gradients(zip(grad, trainable_vars))
        self.loss_tracker.update_state(losses)
        return {'fpn_loss': self.loss_tracker.result()}

    def call(self, inputs):
        """Compute predictions and pyramid features for `inputs`.

        Returns:
            ``((cls2, bbox_reg2, P2), (cls3, bbox_reg3, P3),
            (cls4, bbox_reg4, P4), (cls5, bbox_reg5, P5))``.
        """
        # Stem.
        c1 = self.conv1_conv(inputs)
        c1 = self.conv1_bn(c1)
        c1 = self.conv1_relu(c1)

        # Bottom-up pathway with 1x1 lateral projections.
        c2 = self._run_stage(2, c1)
        f2 = self.featuer2(c2)

        c3 = self._run_stage(3, c2)
        f3 = self.featuer3(c3)

        c4 = self._run_stage(4, c3)
        f4 = self.featuer4(c4)

        c5 = self._run_stage(5, c4)
        P5 = self.featuer5(c5)
        cls5, bbox_reg5 = self.predictor5(P5)

        # Top-down pathway: upsample the coarser map, add the lateral map,
        # then smooth with a 3x3 convolution before predicting.
        up5 = self.upsmapling5(P5)
        M4 = tf.add(up5, f4)
        P4 = self.pred4(M4)
        cls4, bbox_reg4 = self.predictor4(P4)

        up4 = self.upsmapling4(M4)
        M3 = tf.add(up4, f3)
        P3 = self.pred3(M3)
        cls3, bbox_reg3 = self.predictor3(P3)

        up3 = self.upsmapling3(M3)
        M2 = tf.add(up3, f2)
        P2 = self.pred2(M2)
        # Use the P2 head here; the original reused `predictor5` and left
        # `predictor2` unused.
        cls2, bbox_reg2 = self.predictor2(P2)

        return (cls2, bbox_reg2, P2), (cls3, bbox_reg3, P3), (cls4, bbox_reg4, P4), (cls5, bbox_reg5, P5)

In [84]:
# Instantiate the network and call it once on a symbolic 224x224, 3-channel
# input so the resulting output tensors (and their shapes) can be inspected.
fpn = Feature_Pyramid_Network()
inputs = tf.keras.layers.Input(shape=(224, 224, 3))
outputs = fpn(inputs)

In [85]:
# Display the (cls, bbox_reg, feature map) tuple returned for each pyramid level.
outputs

((<tf.Tensor 'feature__pyramid__network_19/predictor_head_11/cls_out/Reshape_1:0' shape=(None, 8748, 1) dtype=float32>,
  <tf.Tensor 'feature__pyramid__network_19/predictor_head_11/reg_out/Reshape_1:0' shape=(None, 8748, 4) dtype=float32>,
  <tf.Tensor 'feature__pyramid__network_19/P2/BiasAdd:0' shape=(None, 54, 54, 256) dtype=float32>),
 (<tf.Tensor 'feature__pyramid__network_19/predictor_head_9/cls_out/Reshape:0' shape=(None, 2028, 1) dtype=float32>,
  <tf.Tensor 'feature__pyramid__network_19/predictor_head_9/reg_out/Reshape:0' shape=(None, 2028, 4) dtype=float32>,
  <tf.Tensor 'feature__pyramid__network_19/P3/BiasAdd:0' shape=(None, 26, 26, 256) dtype=float32>),
 (<tf.Tensor 'feature__pyramid__network_19/predictor_head_10/cls_out/Reshape:0' shape=(None, 432, 1) dtype=float32>,
  <tf.Tensor 'feature__pyramid__network_19/predictor_head_10/reg_out/Reshape:0' shape=(None, 432, 4) dtype=float32>,
  <tf.Tensor 'feature__pyramid__network_19/P4/BiasAdd:0' shape=(None, 12, 12, 256) dtype=flo