[![deep-learning-notes](https://github.com/semilleroCV/deep-learning-notes/raw/main/assets/banner-notebook.png)](https://github.com/semilleroCV/deep-learning-notes)

## **CBAM: Convolutional Block Attention Module**

In [1]:
%%capture
#@title **Install required packages**

! pip install torchinfo

In [1]:
#@title **Importing libraries**

import torch # 2.3.0+cu121
import torch.nn as nn

import torchinfo #1.8.0

In [5]:
# Note: Not all dependencies have the __version__ method.

print(f"torch version: {torch.__version__}")

torch version: 2.3.1+cu121


## **CBAM code**

In [24]:
class ChannelAttention(nn.Module):
    def __init__(self, in_channels: int, ratio: int = 16):
        super().__init__()

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.max_pooling = nn.AdaptiveAvgPool2d(1)

        self.net = nn.Sequential(nn.Conv2d(in_channels, in_channels//ratio, 1),
                                 nn.ReLU(),
                                 nn.Conv2d(in_channels//ratio, in_channels, 1))
        
        self.act = nn.Sigmoid()

    def forward(self, x):
        avg_pool = self.net(self.avg_pooling(x))
        max_pool = self.net(self.max_pooling(x))

        out = self.act(avg_pool + max_pool)

        return out * x

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2)
        self.act = nn.Sigmoid()

    def forward(self, x):

        avg_x = torch.mean(x, dim=1, keepdim=True)
        max_x = torch.amax(x, dim=1, keepdim=True)

        out = torch.cat([avg_x, max_x], dim=1)
        out = self.act(self.conv(out))

        return out * x
    

class CBAM(nn.Module):
    def __init__(self, in_channels: int, ratio: int = 16):
        super().__init__()

        self.ca = ChannelAttention(in_channels, ratio)
        self.sa = SpatialAttention()

    def forward(self, x):

        ca_out = self.ca(x) 

        print(ca_out.shape)
        sa_out = self.sa(ca_out) 

        x = sa_out + x

        return x

In [25]:
cbam_module = CBAM(in_channels=256)
torchinfo.summary(cbam_module, (256, 32, 32), batch_dim = 0)

torch.Size([1, 256, 32, 32])


Layer (type:depth-idx)                   Output Shape              Param #
CBAM                                     [1, 256, 32, 32]          --
├─ChannelAttention: 1-1                  [1, 256, 32, 32]          --
│    └─AdaptiveAvgPool2d: 2-1            [1, 256, 1, 1]            --
│    └─Sequential: 2-2                   [1, 256, 1, 1]            --
│    │    └─Conv2d: 3-1                  [1, 16, 1, 1]             4,112
│    │    └─ReLU: 3-2                    [1, 16, 1, 1]             --
│    │    └─Conv2d: 3-3                  [1, 256, 1, 1]            4,352
│    └─AdaptiveAvgPool2d: 2-3            [1, 256, 1, 1]            --
│    └─Sequential: 2-4                   [1, 256, 1, 1]            (recursive)
│    │    └─Conv2d: 3-4                  [1, 16, 1, 1]             (recursive)
│    │    └─ReLU: 3-5                    [1, 16, 1, 1]             --
│    │    └─Conv2d: 3-6                  [1, 256, 1, 1]            (recursive)
│    └─Sigmoid: 2-5                      [1, 256, 1,