-
Notifications
You must be signed in to change notification settings - Fork 0
/
cbam.py
124 lines (99 loc) · 4.61 KB
/
cbam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from tensorflow.keras.layers import *
from tensorflow.keras import backend as K
from tensorflow.keras.activations import sigmoid
def attach_attention_module(net, attention_module):
    """Apply the attention block selected by name to a feature tensor.

    Args:
        net: Feature tensor handed unchanged to the selected block.
        attention_module: Name of the block to apply, either ``'se_block'``
            or ``'cbam_block'``.

    Returns:
        Whatever the selected attention block returns for ``net``.

    Raises:
        ValueError: If ``attention_module`` is not a supported name.
    """
    if attention_module == 'se_block':  # Squeeze-and-Excitation
        return se_block(net)
    if attention_module == 'cbam_block':  # Convolutional Block Attention Module
        return cbam_block(net)
    # ValueError is a subclass of Exception, so callers that caught the
    # previous bare Exception keep working while new callers can be specific.
    raise ValueError(
        "'{}' is not supported attention module!".format(attention_module))
def se_block(input_feature, ratio=8):
    """Apply a Squeeze-and-Excitation (SE) block to a feature map.

    As described in https://arxiv.org/abs/1709.01507.

    The input is globally average-pooled, passed through a two-layer dense
    bottleneck ending in a sigmoid, and the resulting per-channel weights
    are multiplied back onto the input.

    Args:
        input_feature: Feature tensor accepted by ``GlobalAveragePooling2D``;
            its channel axis is chosen from ``K.image_data_format()``.
        ratio: Divisor applied to the channel count to size the bottleneck
            layer.

    Returns:
        ``input_feature`` multiplied by the learned channel weights.

    Raises:
        ValueError: If the channel dimension of ``input_feature`` is not
            statically known.
    """
    channels_first = K.image_data_format() == "channels_first"
    channel = input_feature.shape[1 if channels_first else -1]
    if channel is None:
        raise ValueError(
            "se_block requires a statically known channel dimension")
    # channel // ratio is 0 when channel < ratio, which is not a valid unit
    # count for Dense; keep at least one unit in the bottleneck.
    reduced = max(1, channel // ratio)

    # Squeeze: one value per channel, reshaped so Dense acts on the last axis.
    se_feature = GlobalAveragePooling2D()(input_feature)
    se_feature = Reshape((1, 1, channel))(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel)

    # Excitation: bottleneck followed by a sigmoid gate per channel.
    se_feature = Dense(reduced,
                       activation='relu',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    assert se_feature.shape[1:] == (1, 1, reduced)
    se_feature = Dense(channel,
                       activation='sigmoid',
                       kernel_initializer='he_normal',
                       use_bias=True,
                       bias_initializer='zeros')(se_feature)
    assert se_feature.shape[1:] == (1, 1, channel)

    if channels_first:
        # Move the channel axis back in front so it lines up with the input.
        se_feature = Permute((3, 1, 2))(se_feature)

    return multiply([input_feature, se_feature])
def cbam_block(cbam_feature, ratio=8):
    """Run channel attention followed by spatial attention on a feature map.

    Implements the Convolutional Block Attention Module (CBAM) block
    described in https://arxiv.org/abs/1807.06521.

    Args:
        cbam_feature: Feature tensor to refine.
        ratio: Forwarded to ``channel_attention``.

    Returns:
        The output of ``spatial_attention`` applied to the
        channel-attended tensor.
    """
    channel_refined = channel_attention(cbam_feature, ratio)
    return spatial_attention(channel_refined)
def channel_attention(input_feature, ratio=8):
    """Reweight the channels of a feature map with CBAM channel attention.

    Global average-pooled and global max-pooled descriptors are each passed
    through the same two dense layers, summed, squashed with a sigmoid, and
    multiplied back onto the input.

    Args:
        input_feature: Feature tensor accepted by the 2D global pooling
            layers; its channel axis is chosen from
            ``K.image_data_format()``.
        ratio: Divisor applied to the channel count to size the first
            shared dense layer.

    Returns:
        ``input_feature`` multiplied by the per-channel attention weights.

    Raises:
        ValueError: If the channel dimension of ``input_feature`` is not
            statically known.
    """
    channels_first = K.image_data_format() == "channels_first"
    channel = input_feature.shape[1 if channels_first else -1]
    if channel is None:
        raise ValueError(
            "channel_attention requires a statically known channel dimension")
    # channel // ratio is 0 when channel < ratio, which is not a valid unit
    # count for Dense; keep at least one unit in the bottleneck.
    reduced = max(1, channel // ratio)

    # Both pooled descriptors go through the SAME layer instances so the
    # weights are shared between the two branches.
    shared_layer_one = Dense(reduced,
                             activation='relu',
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')
    shared_layer_two = Dense(channel,
                             kernel_initializer='he_normal',
                             use_bias=True,
                             bias_initializer='zeros')

    avg_pool = GlobalAveragePooling2D()(input_feature)
    avg_pool = Reshape((1, 1, channel))(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)
    avg_pool = shared_layer_one(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, reduced)
    avg_pool = shared_layer_two(avg_pool)
    assert avg_pool.shape[1:] == (1, 1, channel)

    max_pool = GlobalMaxPooling2D()(input_feature)
    max_pool = Reshape((1, 1, channel))(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)
    max_pool = shared_layer_one(max_pool)
    assert max_pool.shape[1:] == (1, 1, reduced)
    max_pool = shared_layer_two(max_pool)
    assert max_pool.shape[1:] == (1, 1, channel)

    cbam_feature = Add()([avg_pool, max_pool])
    cbam_feature = Activation('sigmoid')(cbam_feature)

    if channels_first:
        # Move the channel axis back in front so it lines up with the input.
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])
def spatial_attention(input_feature, kernel_size=7):
    """Reweight the spatial positions of a feature map with CBAM attention.

    The mean and max over the channel axis are concatenated into a
    two-channel map, convolved down to a single sigmoid-activated channel,
    and multiplied back onto the input.

    Args:
        input_feature: Feature tensor laid out according to
            ``K.image_data_format()``.
        kernel_size: Kernel size of the convolution that produces the
            attention map. Defaults to 7, the value that was previously
            hard-coded.

    Returns:
        ``input_feature`` multiplied by the spatial attention map.
    """
    channels_first = K.image_data_format() == "channels_first"
    if channels_first:
        # Work in channels-last layout internally; restored before returning.
        cbam_feature = Permute((2, 3, 1))(input_feature)
    else:
        cbam_feature = input_feature

    avg_pool = Lambda(lambda x: K.mean(x, axis=3, keepdims=True))(cbam_feature)
    assert avg_pool.shape[-1] == 1
    max_pool = Lambda(lambda x: K.max(x, axis=3, keepdims=True))(cbam_feature)
    assert max_pool.shape[-1] == 1
    concat = Concatenate(axis=3)([avg_pool, max_pool])
    assert concat.shape[-1] == 2

    # The tensor is channels-last at this point regardless of the global
    # setting. Conv2D would otherwise default to K.image_data_format() and,
    # under channels_first, convolve over the wrong axes.
    cbam_feature = Conv2D(filters=1,
                          kernel_size=kernel_size,
                          strides=1,
                          padding='same',
                          data_format='channels_last',
                          activation='sigmoid',
                          kernel_initializer='he_normal',
                          use_bias=False)(concat)
    assert cbam_feature.shape[-1] == 1

    if channels_first:
        # Back to channels-first so the map lines up with the input.
        cbam_feature = Permute((3, 1, 2))(cbam_feature)

    return multiply([input_feature, cbam_feature])