-
Notifications
You must be signed in to change notification settings - Fork 3
/
DeepLabV3plusSE_EfficientNet.py
112 lines (73 loc) · 3.87 KB
/
DeepLabV3plusSE_EfficientNet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, UpSampling2D
from tensorflow.keras.layers import AveragePooling2D, Conv2DTranspose, Concatenate, Input, GlobalAveragePooling2D, Dense, Multiply
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB2
def SqueezeAndExcitation(inputs, ratio=8):
    """Apply a squeeze-and-excitation channel gate to ``inputs``.

    The tensor is globally average-pooled over its spatial axes (squeeze),
    passed through a bottleneck ``Dense(channels // ratio, relu)`` followed by
    ``Dense(channels, sigmoid)`` (excitation), and the resulting per-channel
    weights multiply ``inputs``.

    Args:
        inputs: Keras tensor with channels in the last axis. It must be
            4-D (batch, height, width, channels) because
            ``GlobalAveragePooling2D`` is used.
        ratio: Reduction factor for the hidden (bottleneck) Dense layer.

    Returns:
        Keras tensor with the same shape as ``inputs``.
    """
    channels = inputs.shape[-1]
    # Clamp to at least one unit: when there are fewer channels than
    # ``ratio`` the integer division would give 0, and Dense(0) cannot build.
    hidden_units = max(1, channels // ratio)

    gate = GlobalAveragePooling2D()(inputs)
    gate = Dense(hidden_units, activation='relu', use_bias=False)(gate)
    gate = Dense(channels, activation='sigmoid', use_bias=False)(gate)
    # The (batch, channels) gate is applied to the 4-D input by the Multiply
    # layer, which expands the lower-rank operand so it broadcasts over the
    # spatial axes.
    return Multiply()([inputs, gate])
def ASPP(image_features):
    """Atrous Spatial Pyramid Pooling head used by the DeepLabV3+ model.

    Runs five parallel branches on ``image_features`` and fuses them:

    * image-level pooling: average-pool over the full spatial extent,
      1x1 conv, then bilinear upsampling back to the input size;
    * a plain 1x1 conv;
    * three 3x3 atrous (dilated) convs with rates 6, 12 and 18.

    Each branch is Conv2D(128 filters) -> BatchNormalization -> ReLU. The
    branches are concatenated and projected back to 128 channels with a
    1x1 Conv2D -> BatchNormalization -> ReLU.

    Args:
        image_features: 4-D Keras tensor. Its height and width
            (``shape[1]`` and ``shape[2]``) must be statically known, because
            they size the pooling and upsampling layers.

    Returns:
        Keras tensor with the spatial size of ``image_features`` and 128
        channels.

    Note:
        The BatchNormalization/Activation layers have fixed names
        (``bn_1``..``bn_6``, ``relu_1``..``relu_6``). Calling this function
        more than once in the same model therefore produces duplicate layer
        names.
    """
    shape = image_features.shape

    # Image-level branch: pool the whole feature map to 1x1, then upsample
    # it back so it can be concatenated with the other branches.
    y_pool = AveragePooling2D(pool_size=(shape[1], shape[2]))(image_features)
    y_pool = Conv2D(filters=128, kernel_size=1, padding='same', use_bias=False)(y_pool)
    y_pool = BatchNormalization(name='bn_1')(y_pool)
    y_pool = Activation('relu', name='relu_1')(y_pool)
    y_pool = UpSampling2D((shape[1], shape[2]), interpolation="bilinear")(y_pool)

    y_1 = Conv2D(filters=128, kernel_size=1, padding='same', use_bias=False)(image_features)
    y_1 = BatchNormalization(name='bn_2')(y_1)
    y_1 = Activation('relu', name='relu_2')(y_1)

    y_6 = Conv2D(filters=128, kernel_size=3, padding='same', dilation_rate=6, use_bias=False)(image_features)
    y_6 = BatchNormalization(name='bn_3')(y_6)
    y_6 = Activation('relu', name='relu_3')(y_6)

    # 3x3 kernel: dilation has no effect on a 1x1 kernel, so the previous
    # kernel_size=1 made this branch a duplicate of the plain 1x1 branch.
    y_12 = Conv2D(filters=128, kernel_size=3, padding='same', dilation_rate=12, use_bias=False)(image_features)
    y_12 = BatchNormalization(name='bn_4')(y_12)
    y_12 = Activation('relu', name='relu_4')(y_12)

    # Rate 18 (was mistakenly 6, duplicating the y_6 branch's receptive field).
    y_18 = Conv2D(filters=128, kernel_size=3, padding='same', dilation_rate=18, use_bias=False)(image_features)
    y_18 = BatchNormalization(name='bn_5')(y_18)
    y_18 = Activation('relu', name='relu_5')(y_18)

    # Fuse all branches and project back to 128 channels.
    y_c = Concatenate()([y_pool, y_1, y_6, y_12, y_18])
    y = Conv2D(filters=128, kernel_size=1, padding='same', use_bias=False)(y_c)
    y = BatchNormalization(name='bn_6')(y)
    y = Activation('relu', name='relu_6')(y)
    return y
def DeepLabV3PlusSE(inputs, classes=1, weights='imagenet'):
    """Build a DeepLabV3+ model with squeeze-and-excitation on an EfficientNetB2 encoder.

    High-level features are taken from EfficientNetB2 layer
    ``block6a_expand_bn`` and low-level features from ``block3a_expand_bn``.
    The high-level path goes through SE -> ASPP -> SE and is upsampled x4.
    It is then concatenated with the SE-gated, 1x1-projected low-level
    features, refined by two 3x3 conv blocks, upsampled x4 again and
    classified with a 1x1 conv.

    Args:
        inputs: Input shape tuple passed to ``Input``, e.g. ``(1024, 1024, 3)``.
        classes: Number of output channels. ``1`` produces a sigmoid output;
            any other value produces a softmax output.
        weights: Forwarded to ``EfficientNetB2``. ``'imagenet'`` (the default,
            as before) loads pretrained encoder weights; ``None`` builds a
            randomly initialised encoder.

    Returns:
        An uncompiled ``Model`` from the input image to the prediction map.

    Note:
        The two fixed x4 upsamplings assume the high-level features are 4x
        coarser than the low-level ones, and the low-level ones 4x coarser
        than the input. Input sizes that break this make the concatenation
        fail. TODO(review): confirm supported input sizes (multiples of 32
        are presumably safe).
    """
    input_layer = Input(inputs)
    base_model = EfficientNetB2(
        weights=weights,
        include_top=False,
        input_tensor=input_layer,
        drop_connect_rate=0.25,
    )

    # High-level branch: SE -> ASPP -> SE, upsampled to the low-level resolution.
    high_level_image_features = base_model.get_layer('block6a_expand_bn').output
    high_level_image_features = SqueezeAndExcitation(high_level_image_features, ratio=16)
    x_a = ASPP(high_level_image_features)
    x_a = SqueezeAndExcitation(x_a, ratio=16)
    x_a = UpSampling2D(size=4, interpolation='bilinear')(x_a)

    # Low-level (skip) branch: SE, then a 1x1 projection to 128 channels.
    low_level_image_features = base_model.get_layer('block3a_expand_bn').output
    low_level_image_features = SqueezeAndExcitation(low_level_image_features, ratio=16)
    x_b = Conv2D(filters=128, kernel_size=1, padding='same', use_bias=False)(low_level_image_features)
    x_b = BatchNormalization(name='bn_7')(x_b)
    x_b = Activation('relu', name='relu_7')(x_b)

    # Decoder: fuse both branches and refine with two 3x3 conv blocks.
    x = Concatenate()([x_a, x_b])
    x = SqueezeAndExcitation(x, ratio=16)
    x = Conv2D(filters=128, kernel_size=3, padding='same', use_bias=False)(x)
    x = BatchNormalization(name='bn_8')(x)
    x = Activation('relu', name='relu_8')(x)
    x = Conv2D(filters=128, kernel_size=3, padding='same', use_bias=False)(x)
    x = BatchNormalization(name='bn_9')(x)
    x = Activation('relu', name='relu_9')(x)
    x = UpSampling2D(size=4, interpolation='bilinear')(x)

    # Outputs: per-pixel class scores.
    x = Conv2D(filters=classes, kernel_size=1, name='output_layer')(x)
    x = Activation('sigmoid' if classes == 1 else 'softmax')(x)

    return Model(inputs=input_layer, outputs=x)
def main():
    """Build the 4-class model for 1024x1024x3 inputs and print its layer summary."""
    DeepLabV3PlusSE(inputs=(1024, 1024, 3), classes=4).summary()


if __name__ == '__main__':
    main()