In [1]:
import keras
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, Cropping2D, Concatenate, BatchNormalization, Activation, Flatten
from keras.models import Model
import numpy as np

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# Network input: a 3-channel image (presumably RGB) whose height and width are
# left unspecified (None), so the spatial size is only fixed when data is fed.
# NOTE(review): the stride-2 convolutions and stride-2 transposed convolutions
# used further down presumably require height and width to be multiples of
# 2 ** (scale_depth - 1) for the Concatenate layers to line up - confirm.
input_shape = (None, None, 3)
input_image = Input(shape=input_shape)

# Number of columns in each per-scale row of the layer grids built below.
# Column 0 is the entry tensor of a scale; columns 1..layers-1 are the densely
# connected units.
layers = 3

# Number of resolution scales. Every scale after the first is reached through
# a stride-2 convolution.
scale_depth = 5

# Output channels of the 3x3 convolutions at each scale: 64 at the first scale,
# doubling with every further scale. Derived from scale_depth so the list always
# holds exactly one entry per scale (a hand-written list raises IndexError in
# the encoder as soon as scale_depth grows past its length).
feature_maps = [64 * 2 ** scale for scale in range(scale_depth)]

# Channel width of the 1x1 bottleneck convolutions inside the dense units; also
# reused as the width of the last 3x3 convolution in front of the mask output.
bottleneck = 16

In [3]:
# Tensor grid indexed as encoding_layers[scale][layer]; every cell starts empty
# and is filled in below.
encoding_layers = [[None] * layers for _ in range(scale_depth)]

# Scale 0, column 0: BN -> ReLU -> 3x3 convolution applied to the input image
# at the default stride.
stem = BatchNormalization()(input_image)
stem = Activation('relu')(stem)
encoding_layers[0][0] = Conv2D(feature_maps[0], (3, 3), padding='same')(stem)

# Column 0 of every remaining scale: the same BN -> ReLU -> 3x3 convolution,
# fed by column 0 of the scale directly above and applied with stride 2.
for scale_idx in range(1, scale_depth):
    finer_entry = encoding_layers[scale_idx - 1][0]
    downsampled = BatchNormalization()(finer_entry)
    downsampled = Activation('relu')(downsampled)
    encoding_layers[scale_idx][0] = Conv2D(
        feature_maps[scale_idx], (3, 3), strides=2, padding='same')(downsampled)
    

# MSD encoder: fill columns 1..layers-1 of every scale with densely connected
# units.
def _encoder_unit(inputs, filters, strides=1):
    """Build one dense unit of the encoder grid and return its output tensor.

    The unit is: (Concatenate if several inputs) -> BN -> ReLU ->
    1x1 conv with `bottleneck` filters -> BN -> ReLU -> 3x3 conv.

    Args:
        inputs: non-empty list of tensors feeding the unit. A single tensor is
            used as is; several tensors are concatenated first.
        filters: number of output channels of the closing 3x3 convolution.
        strides: stride of the closing 3x3 convolution (2 when the inputs come
            from the scale above and have to be brought down to this scale).

    Returns:
        The output tensor of the closing 3x3 convolution.

    Reads the notebook-level `bottleneck` setting.
    """
    # A single previous tensor needs no concatenation.
    x = inputs[0] if len(inputs) == 1 else Concatenate()(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(bottleneck, (1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return Conv2D(filters, (3, 3), strides=strides, padding='same')(x)


for scale in range(scale_depth):
    for layer in range(1, layers):
        if scale == 0:
            # The first scale only sees its own earlier columns.
            encoding_layers[scale][layer] = _encoder_unit(
                encoding_layers[scale][:layer], feature_maps[scale])
        else:
            # Earlier columns of the scale above, brought down with a stride-2
            # convolution ...
            out_previous = _encoder_unit(
                encoding_layers[scale - 1][:layer], feature_maps[scale], strides=2)
            # ... and earlier columns of this scale, at the default stride.
            out_current = _encoder_unit(
                encoding_layers[scale][:layer], feature_maps[scale])
            # Both branches emit feature_maps[scale] channels, so the stored
            # tensor carries twice that many.
            encoding_layers[scale][layer] = Concatenate()([out_previous, out_current])
            

In [4]:
# Tensor grid indexed as decoding_layers[scale][layer], same dimensions as the
# encoder grid; every cell starts empty.
decoding_layers = [[None] * layers for _ in range(scale_depth)]

# Column 0 of each decoder scale is the last encoder tensor of the same scale.
for scale_idx, decoder_row in enumerate(decoding_layers):
    decoder_row[0] = encoding_layers[scale_idx][-1]
    
def _decoder_unit(inputs, filters, upsample=False):
    """Build one dense unit of the decoder grid and return its output tensor.

    The unit is: (Concatenate if several inputs) -> BN -> ReLU ->
    1x1 conv with `bottleneck` filters -> BN -> ReLU -> closing 3x3 layer.

    Args:
        inputs: non-empty list of tensors feeding the unit. A single tensor is
            used as is; several tensors are concatenated first.
        filters: number of output channels of the closing 3x3 layer.
        upsample: when True the closing layer is a stride-2 Conv2DTranspose
            (used for inputs coming from the scale below); otherwise it is a
            Conv2D at the default stride.

    Returns:
        The output tensor of the closing 3x3 layer.

    Reads the notebook-level `bottleneck` setting.
    """
    # A single previous tensor needs no concatenation.
    x = inputs[0] if len(inputs) == 1 else Concatenate()(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(bottleneck, (1, 1), padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if upsample:
        return Conv2DTranspose(filters, (3, 3), strides=2, padding='same')(x)
    return Conv2D(filters, (3, 3), padding='same')(x)


# Walk the scales from the last one back to the first so that, for every scale,
# the row of the scale below it (scale + 1) is already complete.
for scale in reversed(range(scale_depth)):
    for layer in range(1, layers):
        if scale == scale_depth - 1:
            # The last scale has nothing below it: it only sees its own
            # earlier columns.
            decoding_layers[scale][layer] = _decoder_unit(
                decoding_layers[scale][:layer], feature_maps[scale])
        else:
            # Earlier columns of the scale below, brought up with a stride-2
            # transposed convolution ...
            out_previous = _decoder_unit(
                decoding_layers[scale + 1][:layer], feature_maps[scale], upsample=True)
            # ... and earlier columns of this scale.
            out_current = _decoder_unit(
                decoding_layers[scale][:layer], feature_maps[scale])
            # Both branches emit feature_maps[scale] channels, so the stored
            # tensor carries twice that many.
            decoding_layers[scale][layer] = Concatenate()([out_previous, out_current])

# Merge the last column across scales: the last tensor of scale i + 1 goes
# through BN -> ReLU -> stride-2 transposed convolution and is concatenated
# onto the last tensor of scale i.
#
# The scales are visited from the second-to-last one back to scale 0, so every
# step consumes the already merged tensor of the scale below it. Visiting them
# in ascending order reads decoding_layers[i + 1][-1] before it has been
# merged; since only decoding_layers[0][-1] is used afterwards, every step
# with i >= 1 then builds layers that never reach the model output.
#
# NOTE(review): compared with ascending order this connects additional layers
# to the output and changes the parameter count, so weights saved from the
# previous graph will not load - confirm the cascade is the intended design.
for i in reversed(range(scale_depth - 1)):
    x = BatchNormalization()(decoding_layers[i + 1][-1])
    x = Activation('relu')(x)
    x = Conv2DTranspose(feature_maps[i], (3, 3), strides=2, padding='same')(x)
    decoding_layers[i][-1] = Concatenate()([decoding_layers[i][-1], x])

# Output head applied to the last tensor of scale 0:
# BN -> ReLU -> 3x3 conv (`bottleneck` filters) -> BN -> ReLU ->
# 1x1 conv with a single sigmoid-activated output channel.
head_layers = [
    BatchNormalization(),
    Activation('relu'),
    Conv2D(bottleneck, (3, 3), padding='same'),
    BatchNormalization(),
    Activation('relu'),
    Conv2D(1, (1, 1), padding='same', activation='sigmoid'),
]

mask = decoding_layers[0][-1]
for head_layer in head_layers:
    mask = head_layer(mask)

In [5]:
# Wrap the graph into a Keras Model mapping the image input to the sigmoid mask.
# Only layers on a path from `input_image` to `mask` become part of the model.
model = Model(inputs=input_image, outputs=mask)

In [6]:
# Print the layer table. The default line width cuts off the "Layer (type)" and
# "Output Shape" columns for this model (e.g. "(None, None, None, 6"), so a
# wider line is requested to keep the channel counts and layer types readable.
model.summary(line_length=140)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, None, None, 3 0                                            
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 3 12          input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 3 0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 6 1792        activation_1[0][0]               
__________________________________________________________________________________________________
batch_norm