**Helper functions to construct the network**

In [0]:
def _bn_relu(input):
    """Apply batch normalization followed by a ReLU activation.

    Args:
        input: Tensor to normalize and activate.

    Returns:
        The output tensor of the ReLU activation layer.

    Note:
        Normalization runs along the module-level ``CHANNEL_AXIS``, which is
        assigned by ``_handle_dim_ordering()``; that function must have run
        before this helper is called.
    """
    normalize = BatchNormalization(axis=CHANNEL_AXIS)
    rectify = Activation("relu")
    return rectify(normalize(input))

In [0]:
def _conv_bn_relu(**conv_params):
    """Helper to build a conv -> BN -> relu block.

    Keyword Args:
        nb_filter: Number of convolution filters (required).
        nb_row: Kernel height, i.e. number of kernel rows (required).
        nb_col: Kernel width, i.e. number of kernel columns (required).
        subsample: Convolution strides. Defaults to ``(1, 1)``.
        init: Kernel initializer. Defaults to ``"he_normal"``.
        border_mode: Padding mode. Defaults to ``"same"``.
        W_regularizer: Kernel regularizer. Defaults to ``l2(1.e-4)``.

    Returns:
        A function that takes a tensor and returns the output of
        conv -> batch normalization -> relu applied to it.

    Raises:
        KeyError: If ``nb_filter``, ``nb_row`` or ``nb_col`` is missing.
    """
    nb_filter = conv_params["nb_filter"]
    nb_row = conv_params["nb_row"]
    nb_col = conv_params["nb_col"]
    # Plain lookups: there is no need to mutate the kwargs dict.
    subsample = conv_params.get("subsample", (1, 1))
    init = conv_params.get("init", "he_normal")
    border_mode = conv_params.get("border_mode", "same")
    W_regularizer = conv_params.get("W_regularizer", l2(1.e-4))

    def f(input):
        # Keras expects kernel_size as (height, width), i.e. (rows, cols).
        conv = Conv2D(filters=nb_filter, kernel_size=(nb_row, nb_col),
                      strides=subsample, kernel_initializer=init,
                      padding=border_mode,
                      kernel_regularizer=W_regularizer)(input)
        return _bn_relu(conv)

    return f

In [0]:
def _bn_relu_conv(**conv_params):
    """Helper to build a BN -> relu -> conv block.

    This ordering (often called "full pre-activation") is the improved scheme
    proposed in http://arxiv.org/pdf/1603.05027v2.pdf

    Keyword Args:
        nb_filter: Number of convolution filters (required).
        nb_row: Kernel height, i.e. number of kernel rows (required).
        nb_col: Kernel width, i.e. number of kernel columns (required).
        subsample: Convolution strides. Defaults to ``(1, 1)``.
        init: Kernel initializer. Defaults to ``"he_normal"``.
        border_mode: Padding mode. Defaults to ``"same"``.
        W_regularizer: Kernel regularizer. Defaults to ``l2(1.e-4)``.

    Returns:
        A function that takes a tensor and returns the output of
        batch normalization -> relu -> conv applied to it.

    Raises:
        KeyError: If ``nb_filter``, ``nb_row`` or ``nb_col`` is missing.
    """
    nb_filter = conv_params["nb_filter"]
    nb_row = conv_params["nb_row"]
    nb_col = conv_params["nb_col"]
    # Plain lookups: there is no need to mutate the kwargs dict.
    subsample = conv_params.get("subsample", (1, 1))
    init = conv_params.get("init", "he_normal")
    border_mode = conv_params.get("border_mode", "same")
    W_regularizer = conv_params.get("W_regularizer", l2(1.e-4))

    def f(input):
        activation = _bn_relu(input)
        # Keras expects kernel_size as (height, width), i.e. (rows, cols).
        return Conv2D(filters=nb_filter, kernel_size=(nb_row, nb_col),
                      strides=subsample, kernel_initializer=init,
                      padding=border_mode,
                      kernel_regularizer=W_regularizer)(activation)

    return f

In [0]:
def _shortcut(input, residual):
    """Adds a shortcut between input and residual block and merges them with "sum".

    When the spatial size or the channel count of ``input`` differs from that
    of ``residual``, ``input`` is first projected with a strided 1 x 1
    convolution so that both tensors can be summed element-wise.

    Args:
        input: Tensor entering the residual block (the shortcut branch).
        residual: Tensor produced by the block's convolution branch.

    Returns:
        The element-wise sum of the (possibly projected) shortcut and
        ``residual``.

    Note:
        Reads the module-level ``ROW_AXIS``, ``COL_AXIS`` and ``CHANNEL_AXIS``
        set by ``_handle_dim_ordering()``. Both tensors need statically known
        spatial dimensions, because the strides are derived from
        ``K.int_shape``.
    """
    # Local import keeps this fix self-contained.
    # NOTE(review): assumes the standalone `keras` package, which the file's
    # use of `K.common.image_dim_ordering()` suggests - confirm against the
    # notebook's import cell.
    from keras.layers import add

    # Expand channels of shortcut to match residual.
    # Stride appropriately to match residual (rows, cols).
    # Ratios should be integers if the network architecture is correctly
    # configured.
    input_shape = K.int_shape(input)
    residual_shape = K.int_shape(residual)
    stride_rows = int(round(input_shape[ROW_AXIS] / residual_shape[ROW_AXIS]))
    stride_cols = int(round(input_shape[COL_AXIS] / residual_shape[COL_AXIS]))
    equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS]

    shortcut = input  # identity

    # 1 X 1 conv if shape is different. Else identity.
    if stride_rows > 1 or stride_cols > 1 or not equal_channels:
        shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS],
                          kernel_size=(1, 1),
                          strides=(stride_rows, stride_cols),
                          kernel_initializer="he_normal", padding="valid",
                          kernel_regularizer=l2(0.0001))(input)

    # A residual connection is a sum; concatenating would stack the channels
    # of both branches instead of adding the learned residual to the input.
    return add([shortcut, residual])

In [0]:
def _residual_block(block_function, nb_filter, repetitions, is_first_layer=False):
    """Builds a residual block with repeating bottleneck blocks.
    """
    def f(input):
        for i in range(repetitions):
            init_subsample = (1, 1)
            if i == 0 and not is_first_layer:
                init_subsample = (2, 2)
            # remeber that block_function is a function pointer that will
            # point to basic_block or bottleneck.
            # each call to this function returns a 
            input = block_function(nb_filter=nb_filter, init_subsample=init_subsample,
                                   is_first_block_of_first_layer=(is_first_layer and i == 0))(input)
        return input

    return f

In [0]:
def basic_block(nb_filter, init_subsample=(1, 1), is_first_block_of_first_layer=False):
    """Basic block made of two 3 x 3 convolutions, for resnets with <= 34 layers.

    Follows the improved scheme proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf

    Layout: input -> (bn-relu-conv) -> (bn-relu-conv) -> output, with a
    shortcut from input to output built by ``_shortcut``.

    Args:
        nb_filter: Number of filters of both convolutions.
        init_subsample: Strides of the first convolution.
        is_first_block_of_first_layer: When true, the first convolution is
            applied without a preceding bn -> relu.

    Returns:
        A function mapping an input tensor to the block's output tensor.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            first_stage = _bn_relu_conv(nb_filter=nb_filter, nb_row=3, nb_col=3,
                                        subsample=init_subsample)
        else:
            # bn -> relu -> maxpool was just applied upstream, so do not
            # repeat bn -> relu before this convolution.
            first_stage = Conv2D(filters=nb_filter,
                                 kernel_size=(3, 3),
                                 strides=init_subsample,
                                 kernel_initializer="he_normal",
                                 padding="same",
                                 kernel_regularizer=l2(0.0001))
        hidden = first_stage(input)

        second_stage = _bn_relu_conv(nb_filter=nb_filter, nb_row=3, nb_col=3)
        return _shortcut(input, second_stage(hidden))

    return f

In [0]:
def bottleneck(nb_filter, init_subsample=(1, 1), is_first_block_of_first_layer=False):
    """Bottleneck block (1 x 1 -> 3 x 3 -> 1 x 1) for resnets with > 34 layers.

    Follows the improved scheme proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf

    Args:
        nb_filter: Number of filters of the first two convolutions; the last
            1 x 1 convolution uses ``nb_filter * 4`` filters.
        init_subsample: Strides of the first 1 x 1 convolution.
        is_first_block_of_first_layer: When true, the first convolution is
            applied without a preceding bn -> relu.

    Returns:
        A function mapping an input tensor to the result of ``_shortcut``
        applied to that input and the ``nb_filter * 4`` residual branch.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            reduce_stage = _bn_relu_conv(nb_filter=nb_filter, nb_row=1, nb_col=1,
                                         subsample=init_subsample)
        else:
            # bn -> relu -> maxpool was just applied upstream, so do not
            # repeat bn -> relu before this convolution.
            reduce_stage = Conv2D(filters=nb_filter,
                                  kernel_size=(1, 1),
                                  strides=init_subsample,
                                  kernel_initializer="he_normal",
                                  padding="same",
                                  kernel_regularizer=l2(0.0001))
        reduced = reduce_stage(input)

        spatial = _bn_relu_conv(nb_filter=nb_filter, nb_row=3, nb_col=3)(reduced)
        expanded = _bn_relu_conv(nb_filter=nb_filter * 4, nb_row=1, nb_col=1)(spatial)
        return _shortcut(input, expanded)

    return f

In [0]:
def _handle_dim_ordering():
    """Set the global axis indices according to the backend dimension ordering.

    Assigns the module-level ``ROW_AXIS``, ``COL_AXIS`` and ``CHANNEL_AXIS``:
    (1, 2, 3) when ``K.common.image_dim_ordering()`` reports ``'tf'``,
    otherwise channels at axis 1 followed by rows (2) and columns (3).
    """
    global ROW_AXIS, COL_AXIS, CHANNEL_AXIS
    channels_last = K.common.image_dim_ordering() == 'tf'
    if channels_last:
        ROW_AXIS, COL_AXIS, CHANNEL_AXIS = 1, 2, 3
    else:
        CHANNEL_AXIS, ROW_AXIS, COL_AXIS = 1, 2, 3

In [0]:
def _get_block(identifier):
    """Resolve a block function given either the function or its name.

    Args:
        identifier: A block function, or the name (string) of a module-level
            object to look up in this module's globals.

    Returns:
        ``identifier`` unchanged when it is not a string, otherwise the
        module-level object registered under that name.

    Raises:
        ValueError: If the name is unknown (or maps to a falsy object).
    """
    if not isinstance(identifier, six.string_types):
        return identifier

    resolved = globals().get(identifier)
    if resolved:
        return resolved
    raise ValueError('Invalid {}'.format(identifier))

**Functions to get a ResNet**

In [0]:
class ResnetBuilder(object):
    """Collection of static factories that assemble ResNet-style Keras models."""

    @staticmethod
    def build(input_shape, num_outputs, block_fn, repetitions):
        """Builds a custom ResNet like architecture.

        Args:
            input_shape: The input shape in the form (nb_channels, nb_rows, nb_cols).
                It is permuted to (nb_rows, nb_cols, nb_channels) internally
                when the backend dimension ordering is 'tf'.
            num_outputs: The number of outputs at final softmax layer
            block_fn: The block function to use, or its name as a string.
                This is either `basic_block` or `bottleneck`.
                The original paper used basic_block for layers < 50
            repetitions: Number of repetitions of the block unit in each
                stage, e.g. [3, 4, 6, 3]. The number of filters starts at 64
                and is doubled after every stage; every stage except the
                first is asked to downsample with stride (2, 2) in its first
                unit.

        Returns:
            The keras `Model`.

        Raises:
            ValueError: If `input_shape` does not have exactly 3 entries.
        """
        # Validate before touching any global state or building layers.
        if len(input_shape) != 3:
            raise ValueError("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")

        _handle_dim_ordering()

        # Permute dimension order if necessary
        if K.common.image_dim_ordering() == 'tf':
            input_shape = (input_shape[1], input_shape[2], input_shape[0])

        # Load function from str if needed.
        block_fn = _get_block(block_fn)

        # Stem: 64 filters of 7x7 with /2 downsampling, then 3x3 max pooling.
        input = Input(shape=input_shape)
        conv1 = _conv_bn_relu(nb_filter=64, nb_row=7, nb_col=7, subsample=(2, 2))(input)
        pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)

        block = pool1
        nb_filter = 64
        for i, r in enumerate(repetitions):
            # _residual_block returns a function; calling it on `block`
            # stacks `r` units built by block_fn.
            block = _residual_block(block_fn, nb_filter=nb_filter, repetitions=r,
                                    is_first_layer=(i == 0))(block)
            # Double the number of filters for the next stage,
            # e.g. 64 -> 128 -> 256 -> 512.
            nb_filter *= 2

        # Last activation. Applied exactly once: repeating BN -> relu right
        # after itself only adds a redundant normalization layer.
        block = _bn_relu(block)

        # Classifier block: average over the whole remaining feature map.
        block_shape = K.int_shape(block)
        pool2 = AveragePooling2D(pool_size=(block_shape[ROW_AXIS], block_shape[COL_AXIS]),
                                 strides=(1, 1))(block)
        flatten1 = Flatten()(pool2)
        dense = Dense(units=num_outputs, kernel_initializer="he_normal",
                      activation="softmax")(flatten1)

        # The model includes all layers needed to compute `dense` from `input`.
        return Model(inputs=input, outputs=dense)

    @staticmethod
    def build_resnet_test(input_shape, num_outputs):
        """Build a minimal network: basic blocks with repetitions [1, 1, 1, 1]."""
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [1, 1, 1, 1])

    @staticmethod
    def build_resnet_18(input_shape, num_outputs):
        """Build a network of basic blocks with repetitions [2, 2, 2, 2]."""
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2])

    @staticmethod
    def build_resnet_34(input_shape, num_outputs):
        """Build a network of basic blocks with repetitions [3, 4, 6, 3].

        `input_shape` is (img_channels, img_rows, img_cols).
        """
        return ResnetBuilder.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3])

    @staticmethod
    def build_resnet_50(input_shape, num_outputs):
        """Build a network of bottleneck blocks with repetitions [3, 4, 6, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3])

    @staticmethod
    def build_resnet_101(input_shape, num_outputs):
        """Build a network of bottleneck blocks with repetitions [3, 4, 23, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3])

    @staticmethod
    def build_resnet_152(input_shape, num_outputs):
        """Build a network of bottleneck blocks with repetitions [3, 8, 36, 3]."""
        return ResnetBuilder.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3])