1. https://www.kaggle.com/code/iommarz8/cnn-with-cbam-attention
2. https://youtu.be/O-eVuz5TU2E?si=C-9EOChIb-j4DNhb
3. https://github.com/Peachypie98/CBAM
4. https://paperswithcode.com/method/channel-attention-module

# Importing Libraries

In [76]:
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, GlobalMaxPooling2D, Reshape, Dense, Input
from tensorflow.keras.layers import Activation, Add, Concatenate, Conv2D, Multiply
from keras.models import Model
from tensorflow.keras.layers import Lambda

# Channel Attention Module

In [63]:
def ChannelAttentionModule(input_feature, ratio=8):
    """Apply CBAM channel attention to a feature map.

    A global-average-pooled and a global-max-pooled descriptor of
    ``input_feature`` are each passed through the same two-layer MLP. The two
    results are summed, squashed with a sigmoid into one weight per channel,
    and used to rescale ``input_feature``.

    Args:
        input_feature: Keras tensor. Assumed channels-last
            ``(batch, height, width, channels)``: the pooling layers use their
            default data format and the weights are reshaped to
            ``(1, 1, channels)``. TODO confirm for non-default data formats.
        ratio: Reduction ratio of the MLP's hidden layer, which gets
            ``channels // ratio`` units (clamped to at least 1).

    Returns:
        Keras tensor with the same shape as ``input_feature``.

    Raises:
        ValueError: If the channel dimension is not statically known, or if
            ``ratio`` is smaller than 1.
    """
    channels = input_feature.shape[-1]  # Number of channels (last axis).
    if channels is None:
        raise ValueError(
            "ChannelAttentionModule requires a statically known channel dimension."
        )
    if ratio < 1:
        raise ValueError(f"ratio must be >= 1, got {ratio!r}")

    # Clamp so that a ratio larger than the channel count cannot produce
    # Dense(0), which Keras rejects.
    hidden_units = max(1, channels // ratio)

    ## Shared MLP: the same layer instances are reused for both branches,
    ## so the avg- and max-pooled descriptors share weights.
    shared_layer1 = Dense(hidden_units, activation='relu', use_bias=False)
    shared_layer2 = Dense(channels, use_bias=False)

    ## Average Pooling.
    avg_pool = GlobalAveragePooling2D()(input_feature)
    avg_pool = shared_layer1(avg_pool)
    avg_pool = shared_layer2(avg_pool)

    ## Max Pooling.
    max_pool = GlobalMaxPooling2D()(input_feature)
    max_pool = shared_layer1(max_pool)
    max_pool = shared_layer2(max_pool)

    # Sum both descriptors with an explicit layer (consistent with the
    # Multiply/Concatenate layers used elsewhere), then apply a sigmoid.
    features = Add()([avg_pool, max_pool])
    features = Activation("sigmoid")(features)
    # Reshape to (1, 1, channels) so the weights broadcast over height/width.
    features = Reshape((1, 1, channels))(features)

    ## Multiply with the input image features.
    refined_features = Multiply()([input_feature, features])

    return refined_features

# Spatial Attention Module

In [69]:
def SpatialAttentionModule(input_feature, kernel_size=7):
    """Apply CBAM spatial attention to a feature map.

    The feature map is reduced across its last axis by mean and by max. The
    two single-channel maps are concatenated, and a
    ``kernel_size`` x ``kernel_size`` convolution with a sigmoid turns them
    into one weight per spatial position. Those weights rescale
    ``input_feature``.

    Args:
        input_feature: Keras tensor. Assumed channels-last
            ``(batch, height, width, channels)``, since the reductions run over
            ``axis=-1``. TODO confirm for other layouts.
        kernel_size: Size of the attention convolution kernel. The default of
            7 matches the previously hard-coded value.

    Returns:
        Keras tensor with the same shape as ``input_feature``.
    """
    ## Average Pooling across channels (keepdims keeps a size-1 channel axis).
    avg_pool = Lambda(lambda x: tf.reduce_mean(x, axis=-1, keepdims=True))(input_feature)

    ## Max Pooling across channels.
    max_pool = Lambda(lambda x: tf.reduce_max(x, axis=-1, keepdims=True))(input_feature)

    ## Concatenate into a two-channel descriptor.
    features = Concatenate(axis=-1)([avg_pool, max_pool])

    ## Convolutional layer. "same" padding keeps height/width unchanged so the
    ## single-channel attention map lines up with input_feature.
    attention = Conv2D(1, kernel_size=kernel_size, padding="same", activation="sigmoid")(features)

    return Multiply()([input_feature, attention])

# Convolutional Block Attention Module (CBAM)

In [74]:
def CBAM(input_feature, ratio=8):
    """Refine a feature map with a Convolutional Block Attention Module.

    Channel attention is applied first. Spatial attention is then applied to
    the channel-refined map.

    Args:
        input_feature: Keras tensor, assumed channels-last 4-D (see the two
            attention modules).
        ratio: Reduction ratio forwarded to ``ChannelAttentionModule``. The
            default of 8 matches the previously hard-coded value.

    Returns:
        Keras tensor produced by ``SpatialAttentionModule`` applied to the
        output of ``ChannelAttentionModule``.
    """
    channel_refined = ChannelAttentionModule(input_feature, ratio=ratio)
    return SpatialAttentionModule(channel_refined)

# Testing the CBAM Block

In [77]:
if __name__ == "__main__":
    # Smoke test: wrap a single CBAM block in a functional model and print
    # its layer summary.
    demo_input = Input(shape=(128, 128, 32))
    demo_model = Model(demo_input, CBAM(demo_input))
    demo_model.summary()


