Skip to content

Commit

Permalink
Update DPN
Browse files Browse the repository at this point in the history
  • Loading branch information
titu1994 committed Aug 8, 2017
1 parent 361b27b commit f9eaf9a
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions dual_path_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,10 @@ def _bn_relu_conv_block(input, filters, kernel=(3, 3), stride=(1, 1), weight_dec
'''
channel_axis = 1 if K.image_data_format() == 'channels_first' else -1

x = BatchNormalization(axis=channel_axis)(input)
x = Activation('relu')(x)
x = Conv2D(filters, kernel, padding='same', use_bias=False, kernel_initializer='he_normal',
kernel_regularizer=l2(weight_decay), strides=stride)(x)
kernel_regularizer=l2(weight_decay), strides=stride)(input)
x = BatchNormalization(axis=channel_axis)(x)
x = Activation('relu')(x)
return x


Expand All @@ -278,15 +278,12 @@ def _grouped_convolution_block(input, grouped_channels, cardinality, strides, we

if cardinality == 1:
# with cardinality 1, it is a standard convolution
x = BatchNormalization(axis=channel_axis)(init)
x = Activation('relu')(x)
x = Conv2D(grouped_channels, (3, 3), padding='same', use_bias=False, strides=strides,
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(x)
kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay))(init)
x = BatchNormalization(axis=channel_axis)(x)
x = Activation('relu')(x)
return x

input = BatchNormalization(axis=channel_axis)(init)
input = Activation('relu')(input)

for c in range(cardinality):
x = Lambda(lambda z: z[:, :, :, c * grouped_channels:(c + 1) * grouped_channels]
if K.image_data_format() == 'channels_last' else
Expand All @@ -298,6 +295,8 @@ def _grouped_convolution_block(input, grouped_channels, cardinality, strides, we
group_list.append(x)

group_merge = concatenate(group_list, axis=channel_axis)
group_merge = BatchNormalization(axis=channel_axis)(group_merge)
group_merge = Activation('relu')(group_merge)
return group_merge


Expand Down Expand Up @@ -479,3 +478,7 @@ def _create_dpn(nb_classes, img_input, include_top, initial_conv_filters,
x = Lambda(lambda z: 0.5 * z)(x)

return x

if __name__ == '__main__':
model = DPN92((224, 224, 3))
model.summary()

0 comments on commit f9eaf9a

Please sign in to comment.