# Embedding CBAM (Convolutional Block Attention Module) in a CNN (Convolutional Neural Network) to classify the MNIST handwritten-digit dataset

## 1. Import libraries

In [1]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

## 2. Define the CBAM PyTorch modules

In [2]:
# Define the Convolutional Block Attention Module (CBAM)
class ChannelAttention(nn.Module):
    """Channel-attention gate of CBAM.

    Every channel is summarised twice, once by global average pooling and once
    by global max pooling. Both summaries pass through the same two-layer
    bottleneck made of bias-free 1x1 convolutions, the two results are added,
    and a sigmoid turns the sum into one weight per channel.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to at least 1 so that feature maps
            with fewer than ``ratio`` channels still get a usable bottleneck
            instead of a zero-channel convolution.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Without the clamp, in_planes < ratio would yield 0 hidden channels.
        hidden_planes = max(1, in_planes // ratio)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1   = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def _bottleneck(self, pooled):
        """Apply the shared fc1 -> ReLU -> fc2 stack to a pooled descriptor."""
        return self.fc2(self.relu1(self.fc1(pooled)))

    def forward(self, x):
        """Return per-channel weights for ``x``.

        For an ``(N, C, H, W)`` input the result has shape ``(N, C, 1, 1)``,
        so it broadcasts over the spatial dimensions when multiplied with
        ``x``.
        """
        avg_out = self._bottleneck(self.avg_pool(x))
        max_out = self._bottleneck(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)

class SpatialAttention(nn.Module):
    """Spatial-attention gate of CBAM.

    The input is collapsed along the channel axis in two ways (mean and max),
    the two resulting single-channel maps are stacked, and one bias-free
    convolution followed by a sigmoid produces a single-channel weight map.

    Args:
        kernel_size: Side length of the convolution kernel. Padding is
            ``kernel_size // 2``, so an odd kernel keeps the spatial size of
            the input.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=half_kernel,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a one-channel weight map computed from ``x``."""
        # Two per-pixel descriptors, each reduced over dim=1 (the channel axis).
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True)[0]
        descriptors = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(descriptors))

class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Re-weights a feature map in two steps: first by the per-channel weights
    from ``ChannelAttention``, then by the per-pixel weights from
    ``SpatialAttention``. The output has the same shape as the input.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Bottleneck reduction factor forwarded to ``ChannelAttention``.
            Defaults to 16, the value that was previously implicit.
        kernel_size: Convolution kernel size forwarded to
            ``SpatialAttention``. Defaults to 7, the value that was
            previously implicit.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Apply channel attention, then spatial attention, to ``x``."""
        # Order matters: the spatial gate is computed from the
        # channel-refined features, not from the raw input.
        x = x * self.ca(x)
        x = x * self.sa(x)
        return x

## 3. Define the CNN model (which uses the previously defined CBAM)

In [29]:
# Define the CNN model with CBAM
class CNN(nn.Module):
    """Three-stage convolutional classifier with a CBAM block after each stage.

    Each stage is ``Conv2d(kernel_size=3) -> max_pool2d(2) -> ReLU -> CBAM``,
    with 64, 128 and 256 output channels respectively. The final feature map
    is averaged over its spatial positions and classified by two linear
    layers.

    For a 28x28 input the spatial size evolves as
    28 -> 26 -> 13 -> 11 -> 5 -> 3 -> 1, so the last map is already 1x1 and
    the global average leaves it unchanged. Larger inputs, which previously
    failed at ``fc1`` with a shape mismatch, are reduced to 256 features by
    the same pooling step.

    Args:
        in_channels: Number of channels of the input images (1 for MNIST).
        num_classes: Number of output logits (10 for MNIST).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3)
        self.cbam1 = CBAM(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)
        self.cbam2 = CBAM(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3)
        self.cbam3 = CBAM(256)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(N, num_classes)``.

        ``x`` is expected to be a batched ``(N, in_channels, H, W)`` tensor
        large enough to survive the three unpadded conv/pool stages.
        No softmax is applied here.
        """
        out = torch.relu(torch.max_pool2d(self.conv1(x), 2))
        out = self.cbam1(out)
        out = torch.relu(torch.max_pool2d(self.conv2(out), 2))
        out = self.cbam2(out)
        out = torch.relu(torch.max_pool2d(self.conv3(out), 2))
        out = self.cbam3(out)
        # Collapse whatever spatial extent is left to 1x1 so fc1 always
        # receives exactly 256 features (identity when the map is already 1x1).
        out = nn.functional.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)  # (N, 256, 1, 1) -> (N, 256)
        out = torch.relu(self.fc1(out))
        out = self.fc2(out)
        return out

## 4. Load the MNIST dataset

In [None]:
# Load the MNIST dataset.
# ToTensor is the only preprocessing step applied to the images.
transform = transforms.Compose([transforms.ToTensor()])

# download=True fetches the files into ./data only when they are missing, so
# the cell also works on a fresh checkout (Restart Kernel -> Run All) instead
# of failing because the dataset is not on disk yet.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

print(train_dataset)
print(test_dataset)

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )
Dataset MNIST
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
           )


## 5. Create dataloaders from the previous datasets

In [31]:
# Wrap both splits in mini-batch loaders of 64 samples.
# Only the training loader shuffles; the test loader keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

## 6. Model, loss function and optimizer initialization

In [40]:
# Initialize the model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Will train with device {device}')

# Fix the global RNG so weight initialisation (and the shuffling order drawn
# from the same generator) is repeatable when the notebook is re-run.
torch.manual_seed(0)

# Move the model to its device first, then hand its parameters to the
# optimizer, so the optimizer is built from the parameters that get trained.
model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Last expression of the cell: display the architecture.
model

Will train with device cuda


CNN(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
  (cbam1): CBAM(
    (ca): ChannelAttention(
      (avg_pool): AdaptiveAvgPool2d(output_size=1)
      (max_pool): AdaptiveMaxPool2d(output_size=1)
      (fc1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu1): ReLU()
      (fc2): Conv2d(4, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (sigmoid): Sigmoid()
    )
    (sa): SpatialAttention(
      (conv1): Conv2d(2, 1, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (sigmoid): Sigmoid()
    )
  )
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
  (cbam2): CBAM(
    (ca): ChannelAttention(
      (avg_pool): AdaptiveAvgPool2d(output_size=1)
      (max_pool): AdaptiveMaxPool2d(output_size=1)
      (fc1): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu1): ReLU()
      (fc2): Conv2d(8, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (sigmoid): Sigmoid()
    )
    (sa): Spati

## 7. Model training

In [37]:
# Train the model
NUM_EPOCHS = 10   # full passes over the training set
LOG_EVERY = 100   # print the current batch loss every LOG_EVERY steps

# Put the model in training mode explicitly: the evaluation cell switches it
# to eval mode, and this cell may be re-run after that one.
model.train()

for epoch in range(NUM_EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Standard optimisation step: clear old gradients, forward pass,
        # loss, backward pass, parameter update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if (i+1) % LOG_EVERY == 0:
            # The value shown is the loss of this single mini-batch.
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, NUM_EPOCHS, i+1, len(train_loader), loss.item()))

Epoch [1/10], Step [100/938], Loss: 0.5849
Epoch [1/10], Step [200/938], Loss: 0.2238
Epoch [1/10], Step [300/938], Loss: 0.1793
Epoch [1/10], Step [400/938], Loss: 0.2797
Epoch [1/10], Step [500/938], Loss: 0.1381
Epoch [1/10], Step [600/938], Loss: 0.1300
Epoch [1/10], Step [700/938], Loss: 0.0552
Epoch [1/10], Step [800/938], Loss: 0.0626
Epoch [1/10], Step [900/938], Loss: 0.0576
Epoch [2/10], Step [100/938], Loss: 0.0914
Epoch [2/10], Step [200/938], Loss: 0.1597
Epoch [2/10], Step [300/938], Loss: 0.0597
Epoch [2/10], Step [400/938], Loss: 0.1022
Epoch [2/10], Step [500/938], Loss: 0.0636
Epoch [2/10], Step [600/938], Loss: 0.1937
Epoch [2/10], Step [700/938], Loss: 0.0701
Epoch [2/10], Step [800/938], Loss: 0.0433
Epoch [2/10], Step [900/938], Loss: 0.0345
Epoch [3/10], Step [100/938], Loss: 0.0278
Epoch [3/10], Step [200/938], Loss: 0.0126
Epoch [3/10], Step [300/938], Loss: 0.0169
Epoch [3/10], Step [400/938], Loss: 0.0268
Epoch [3/10], Step [500/938], Loss: 0.0352
Epoch [3/10

## 8. Model evaluation

In [39]:
# Test the model: share of test images whose highest-scoring class matches the label
model.eval()
correct, total = 0, 0

# Gradients are not needed for scoring, so autograd tracking is switched off.
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Index of the largest logit along dim 1 is the predicted class.
        predicted = outputs.max(dim=1)[1]
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

print(f'Test Accuracy: {100 * correct / total} %')

Test Accuracy: 98.86 %
