In [43]:
from keras.models import Sequential, Model
from keras.layers import Dense, Dropout, Flatten, BatchNormalization, Input, Reshape, Add, Activation
from keras.layers import Conv2DTranspose, UpSampling2D
from keras.layers import Conv2D, MaxPooling2D, Maximum, AveragePooling2D, Concatenate, LocallyConnected1D, GlobalMaxPooling2D
from keras.initializers import he_normal

In [18]:
# Notebook configuration.
#   img_size    -- side length (pixels) of the square network input
#   num_classes -- number of output units, one independent sigmoid per label
img_size, num_classes = 128, 17

In [67]:
def make_linear_bn(_x_, out_channels):
    """Dense layer with ReLU activation followed by batch normalization.

    Args:
        _x_: input Keras tensor (presumably rank 2, (batch, features) -- in this
            notebook it is fed from GlobalMaxPooling2D; TODO confirm for other callers).
        out_channels: number of units in the Dense layer.

    Returns:
        The batch-normalized output tensor.
    """
    dense_out = Dense(out_channels, activation='relu')(_x_)
    return BatchNormalization()(dense_out)

def make_conv_bn(_x_, out_channels, kernel_size=1, groups=1):
    """Square 'same'-padded Conv2D with ReLU, followed by batch normalization.

    Args:
        _x_: input Keras tensor.
        out_channels: number of convolution filters.
        kernel_size: side length of the square kernel.
        groups: accepted only for signature compatibility -- it is NOT used;
            the convolution is always ungrouped.

    Returns:
        The batch-normalized output tensor. 'same' padding with the default
        stride keeps the spatial size of the input.
    """
    kernel = (kernel_size, kernel_size)
    conv_out = Conv2D(out_channels, kernel, padding='same', activation='relu')(_x_)
    return BatchNormalization()(conv_out)

def preprocess(_x_):
    """Input stem: four stacked 1x1 conv+BN layers with 16 filters each.

    1x1 convolutions mix channels per pixel; spatial size is unchanged.

    Args:
        _x_: input Keras tensor.

    Returns:
        Output tensor with 16 channels.
    """
    for _ in range(4):
        _x_ = make_conv_bn(_x_, 16, kernel_size=1)
    return _x_

def create_conv_btneck(_x_, sizes=(64, 64, 64), groups=1):
    """Bottleneck block: 1x1 -> 3x3 -> 1x1 conv+BN layers.

    Args:
        _x_: input Keras tensor.
        sizes: filter counts for the three convolutions, in order. Any
            indexable sequence of at least three ints (lists still work);
            the default is a tuple to avoid a shared mutable default.
        groups: accepted but NOT used -- all convolutions are ungrouped.

    Returns:
        Output tensor with sizes[2] channels and the input's spatial size
        (every conv uses 'same' padding).
    """
    _x_ = make_conv_bn(_x_, sizes[0], kernel_size=1)
    # The original implementation sometimes used groups=16 for this 3x3 conv;
    # grouped convolution is not applied here.
    _x_ = make_conv_bn(_x_, sizes[1], kernel_size=3)
    _x_ = make_conv_bn(_x_, sizes[2], kernel_size=1)
    return _x_

def create_cls(_x_, num_cls, hidden_units=512):
    """Classification head: two Dense+ReLU+BN hidden layers, then a linear output.

    Args:
        _x_: flat feature tensor (in this notebook, a GlobalMaxPooling2D output).
        num_cls: number of output logits.
        hidden_units: width of both hidden layers (default 512).

    Returns:
        Tensor of num_cls raw logits. The model sums several heads and applies
        one sigmoid, so the output must be unbounded: a ReLU here would zero
        every negative logit (and its gradient) before the sum.
    """
    _x_ = make_linear_bn(_x_, hidden_units)
    _x_ = make_linear_bn(_x_, hidden_units)
    # Output layer is deliberately linear -- no ReLU / BatchNorm on the logits.
    return Dense(num_cls)(_x_)

In [70]:
# Encoder-decoder network with one classification head per scale.
# Encoder: conv1d..conv5d (residual bottlenecks, max-pooled between stages).
# Decoder: conv4u..conv1u, each upsampling the previous decoder output and
# adding the matching encoder map. All head logits are summed, then a sigmoid
# gives one independent probability per class. "# N" comments = spatial size.
input_img = Input(shape=(img_size, img_size, 3))

x = BatchNormalization()(input_img)  # 128
x = preprocess(x)

# ---- encoder ----
conv1d = create_conv_btneck(x, [32, 32, 64])
x = MaxPooling2D(pool_size=2, strides=(2, 2))(conv1d)  # 64

# 1x1 projection so the shortcut's channel count matches the bottleneck output.
short2d = Conv2D(128, (1, 1), padding='same')(x)
conv2d = Add()([create_conv_btneck(x, [64, 64, 128]), short2d])
x = MaxPooling2D(pool_size=2, strides=(2, 2))(conv2d)  # 32
logit2d = create_cls(GlobalMaxPooling2D()(x), num_classes)

short3d = Conv2D(256, (1, 1), padding='same')(x)
conv3d = Add()([create_conv_btneck(x, [128, 128, 256]), short3d])
x = MaxPooling2D(pool_size=2, strides=(2, 2))(conv3d)  # 16
logit3d = create_cls(GlobalMaxPooling2D()(x), num_classes)

short4d = x  # identity shortcut: channels already match (256)
conv4d = Add()([create_conv_btneck(x, [256, 256, 256]), short4d])
x = MaxPooling2D(pool_size=2, strides=(2, 2))(conv4d)  # 8
logit4d = create_cls(GlobalMaxPooling2D()(x), num_classes)

short5d = x  # identity shortcut
conv5d = Add()([create_conv_btneck(x, [256, 256, 256]), short5d])  # 8
# Head on conv5d itself -- pooling `x` here would reuse logit4d's input and
# leave the conv5d stage out of the classifier entirely.
logit5d = create_cls(GlobalMaxPooling2D()(conv5d), num_classes)

# ---- decoder ----
# kernel_size == strides == 2 writes every output pixel exactly once; with
# kernel_size=1 only 1 of every 4 upsampled pixels would be non-zero.
# Each stage upsamples the previous decoder block's output (conv5d, conv4u,
# conv3u, conv2u) so the decoder bottlenecks lie on the main path.
x = Conv2DTranspose(256, kernel_size=2, strides=(2, 2))(conv5d)  # 16
x = Add()([x, conv4d])
conv4u = create_conv_btneck(x, [256, 256, 256])
logit4u = create_cls(GlobalMaxPooling2D()(conv4u), num_classes)

x = Conv2DTranspose(256, kernel_size=2, strides=(2, 2))(conv4u)  # 32
x = Add()([x, conv3d])
conv3u = create_conv_btneck(x, [128, 128, 128])
logit3u = create_cls(GlobalMaxPooling2D()(conv3u), num_classes)

x = Conv2DTranspose(128, kernel_size=2, strides=(2, 2))(conv3u)  # 64
x = Add()([x, conv2d])
conv2u = create_conv_btneck(x, [64, 64, 64])
logit2u = create_cls(GlobalMaxPooling2D()(conv2u), num_classes)

x = Conv2DTranspose(64, kernel_size=2, strides=(2, 2))(conv2u)  # 128
x = Add()([x, conv1d])
conv1u = create_conv_btneck(x, [64, 64, 64])
logit1u = create_cls(GlobalMaxPooling2D()(conv1u), num_classes)

# ---- output ----
out = Add()([logit2d, logit3d, logit4d, logit5d, logit4u, logit3u, logit2u, logit1u])
out = Activation('sigmoid')(out)

model = Model(inputs=input_img, outputs=out)
model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_35 (InputLayer)            (None, 128, 128, 3)   0                                            
____________________________________________________________________________________________________
batch_normalization_918 (BatchNo (None, 128, 128, 3)   12                                           
____________________________________________________________________________________________________
conv2d_620 (Conv2D)              (None, 128, 128, 16)  64                                           
____________________________________________________________________________________________________
batch_normalization_919 (BatchNo (None, 128, 128, 16)  64                                           
___________________________________________________________________________________________