# [PyTorch] Attention Module on Local Features


In addition to the usual global average pooling, this notebook illustrates __an idea__: adding an existing visual attention (channel and spatial) module to model local features. The attention module comes from [DANet](https://arxiv.org/abs/1809.02983); the implementation is their official one, adapted for this particular use case.

With the attention module, the __boost__ to local CV is around 0.005 (~0.945 -> 0.950), and to LB around 0.004 (0.942 (V2) -> 0.946 (V1)). All results are from a single fold (out of 5 folds).

In this notebook, the backbone is efficientnet-b2, but the method can be adapted to larger backbones such as the popular resnet200d or efficientnet-b7.

P.S.: While trying this method, I realized that the quality of training has a huge impact on model performance in this competition (most Kagglers probably know this already lol....).

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import sys
import os
import glob

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import albumentations as a_transform
from torch.utils.data import DataLoader, Dataset

from albumentations.pytorch import ToTensorV2
# Add the EfficientNet-PyTorch source directory to the module search path so
# that the `efficientnet_pytorch` import below can be resolved from it.
effnet_pytorch = "../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master"
sys.path.append(effnet_pytorch)
from efficientnet_pytorch import EfficientNet

# Run on the first CUDA device when PyTorch reports one as available;
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(device)


# Attention Module

The idea is similar to CBAM (inspired by this [kernel](https://www.kaggle.com/ipythonx/tf-keras-ranzcr-multi-attention-efficientnet)) and the [Dual Attention Network](https://github.com/junfu1115/DANet).

In [None]:
class PAM_Module(nn.Module):
    """Position attention module (self-attention over spatial locations).

    Every spatial location attends to every other location of the same
    feature map; the attended features are added back to the input through
    a learnable scalar gate that starts at zero. Ref: SAGAN.
    """

    def __init__(self, in_dim):
        """Build the 1x1 projections.

        Args:
            in_dim: number of channels of the feature map passed to forward.
        """
        super().__init__()
        self.chanel_in = in_dim
        reduced_dim = in_dim // 8

        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        # Residual gate; zero at construction, so the module starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply position attention.

        Args:
            x: input feature maps (B x C x H x W).

        Returns:
            gamma * attended features + x, with the same shape as x.
        """
        batch, channels, height, width = x.size()
        n_pos = width * height

        # Flatten the spatial grid so each location becomes one token.
        queries = self.query_conv(x).view(batch, -1, n_pos).transpose(1, 2)
        keys = self.key_conv(x).view(batch, -1, n_pos)
        values = self.value_conv(x).view(batch, -1, n_pos)

        # (B x N x N) similarity between locations, normalised over the keys.
        attention = torch.bmm(queries, keys).softmax(dim=-1)

        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x


class CAM_Module(nn.Module):
    """Channel attention module (self-attention across channels).

    Channel-to-channel similarities are computed from the flattened feature
    map, and the re-weighted channels are added back to the input through a
    learnable scalar gate that starts at zero.
    """

    def __init__(self, in_dim):
        """Create the gate parameter.

        Args:
            in_dim: number of input channels (stored only; no layer uses it).
        """
        super().__init__()
        self.chanel_in = in_dim
        # Residual gate; zero at construction, so the module starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply channel attention.

        Args:
            x: input feature maps (B x C x H x W).

        Returns:
            gamma * attended features + x, with the same shape as x.
        """
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, -1)

        # (B x C x C) Gram matrix of the channels.
        energy = torch.bmm(flat, flat.transpose(1, 2))
        # Subtract every entry from its row maximum before the softmax, so the
        # least similar channel pairs receive the largest weights.
        row_max = energy.max(dim=-1, keepdim=True)[0]
        attention = torch.softmax(row_max - energy, dim=-1)

        attended = torch.bmm(attention, flat).view(batch, channels, height, width)
        return self.gamma * attended + x


class CBAM(nn.Module):
    """Dual-branch attention head combining channel and position attention.

    The input is processed by two parallel branches, each made of
    conv-BN-ReLU (channel reduction to in_channels // 4), an attention
    module, and conv-BN-ReLU (channel expansion back to in_channels).
    The two branch outputs are summed.
    """

    def __init__(self, in_channels):
        """Build both branches.

        Args:
            in_channels: number of channels of the incoming feature map.
        """
        super().__init__()
        inter_channels = in_channels // 4

        def conv_bn_relu(n_in, n_out):
            # 3x3 convolution that preserves the spatial size, then BN + ReLU.
            return nn.Sequential(
                nn.Conv2d(n_in, n_out, 3, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
                nn.ReLU(),
            )

        # Attribute names and creation order are kept as-is: they determine
        # the state_dict keys and the order in which weights are initialised.
        self.conv1_c = conv_bn_relu(in_channels, inter_channels)
        self.conv1_s = conv_bn_relu(in_channels, inter_channels)

        self.channel_gate = CAM_Module(inter_channels)
        self.spatial_gate = PAM_Module(inter_channels)

        self.conv2_c = conv_bn_relu(inter_channels, in_channels)
        self.conv2_a = conv_bn_relu(inter_channels, in_channels)

    def forward(self, x):
        """Return the sum of the channel-attention and position-attention branches.

        Args:
            x: input feature maps (B x C x H x W).
        """
        channel_branch = self.conv2_c(self.channel_gate(self.conv1_c(x)))
        spatial_branch = self.conv2_a(self.spatial_gate(self.conv1_s(x)))
        return channel_branch + spatial_branch


# EfficientNet-b2 with soft attention

In [None]:
class EffNetWLF(nn.Module):
    """EfficientNet classifier that fuses global and attention-based local features.

    The backbone's feature map is used twice: once through the backbone's own
    average pooling (global features) and once through a CBAM attention head
    whose output is summed over the spatial grid (local features). Both
    vectors are concatenated and fed to a small MLP classifier.
    """

    def __init__(self, model_name, target_size=11):
        """Build the backbone, the attention head and the classifier.

        Args:
            model_name: architecture name passed to EfficientNet.from_name.
            target_size: number of output logits.
        """
        super().__init__()
        backbone = EfficientNet.from_name(model_name)
        n_features = backbone._fc.in_features

        # The backbone's own dropout and fc head are replaced here; forward()
        # below does not call them directly.
        backbone._dropout = nn.Dropout(0.1)
        backbone._fc = nn.Linear(n_features, target_size)
        self.backbone = backbone

        self.local_fe = CBAM(n_features)
        self.dropout = nn.Dropout(0.1)

        # Input width is doubled because global and local features are
        # concatenated in forward().
        self.classifier = nn.Sequential(
            nn.Linear(2 * n_features, n_features),
            nn.BatchNorm1d(n_features),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(n_features, target_size),
        )

    def forward(self, image):
        """Compute logits for a batch of images.

        Args:
            image: batch of input images accepted by the backbone.

        Returns:
            Raw (pre-sigmoid) logits produced by the classifier.
        """
        feature_map = self.backbone.extract_features(image)

        # Global descriptor: the backbone's default average pooling.
        pooled = self.backbone._avg_pooling(feature_map).flatten(start_dim=1)
        global_feas = self.dropout(pooled)

        # Local descriptor: attention-refined map summed over height and width.
        attended = self.local_fe(feature_map)
        local_feas = self.dropout(attended.sum(dim=[2, 3]))

        return self.classifier(torch.cat([global_feas, local_feas], dim=1))


# Dataset and DataLoader

In [None]:
work_dir = "../input/ranzcr-clip-catheter-line-classification/"

# Discover the inputs first: test images, trained checkpoints, and the
# submission template whose columns are filled in later.
df = pd.read_csv(os.path.join(work_dir, "sample_submission.csv"))
test_img_paths = glob.glob(os.path.join(work_dir, "test/*.jpg"))
model_weights = glob.glob("../input/effb2wlf/*.pth")

# Report what was found: number of test images, then the checkpoint paths.
print(len(test_img_paths))
print(model_weights)

In [None]:
class TestDataset(Dataset):
    """Inference dataset yielding ``(uid, image)`` pairs.

    Each image is read as grayscale, contrast-enhanced with CLAHE, expanded
    to three channels, and optionally passed through ``transform``.
    """

    def __init__(self, img_paths, transform=None):
        """Store the paths and build the CLAHE operator once.

        Args:
            img_paths: sequence of image file paths.
            transform: optional callable invoked as ``transform(image=img)``
                that returns a mapping with an ``'image'`` entry.
        """
        self.img_paths = img_paths
        self.transform = transform
        self.clahe = cv2.createCLAHE(clipLimit=30.0, tileGridSize=(8, 8))

    def __len__(self):
        """Return the number of images."""
        return len(self.img_paths)

    @staticmethod
    def _uid_from_path(path):
        """Return the file name of ``path`` without directory and extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def __getitem__(self, idx):
        """Load and preprocess the image at position ``idx``.

        Returns:
            Tuple ``(uid, image)`` where ``uid`` is the file name without its
            extension.

        Raises:
            OSError: if the file is missing or cannot be decoded.
        """
        path = os.fspath(self.img_paths[idx])
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside clahe.apply.
        if image is None:
            raise OSError(f"cv2.imread could not read image: {path}")
        image = self.clahe.apply(image)
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if self.transform is not None:
            image = self.transform(image=image)['image']
        return self._uid_from_path(path), image

## Dataloader

In [None]:
# Per-channel normalisation applied to 0-255 pixel values (the constants look
# like the usual ImageNet mean/std - confirm against the training pipeline).
normalize = a_transform.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
    max_pixel_value=255.0,
    p=1.0,
)

# Deterministic test-time pipeline: resize -> normalise -> convert to tensor.
test_transform = a_transform.Compose(
    [a_transform.Resize(512, 512), normalize, ToTensorV2()],
    p=1.0,
)

test_ds = TestDataset(test_img_paths, transform=test_transform)

# No shuffling and no dropped batches, so every test image is visited once.
dataloader = DataLoader(
    test_ds,
    batch_size=16,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)

In [None]:
# Ensemble inference with horizontal-flip test-time augmentation (TTA).
# For every checkpoint: predict on each batch and its mirrored copy, average
# the two sets of logits, apply a sigmoid, and finally average the
# probabilities across checkpoints before writing the submission file.

if not model_weights:
    # Without this guard the mean over zero models would silently yield NaNs.
    raise FileNotFoundError("No model checkpoints (*.pth) were found for inference.")

final_pred = np.empty((len(model_weights), len(test_ds), 11), dtype=np.float32)
for model_idx, each_w in enumerate(model_weights):
    print(f"running idx {model_idx}")
    model = EffNetWLF("efficientnet-b2")
    model = model.to(device)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    checkpoint = torch.load(each_w, map_location=device)
    model.load_state_dict(checkpoint["model"])
    # Switch to inference mode: Dropout is disabled and BatchNorm uses its
    # stored running statistics instead of per-batch statistics.
    model.eval()

    uids = []
    all_pred = []
    with torch.no_grad():
        for name, x_mb in dataloader:
            x = x_mb.to(device)
            # First half: original images; second half: horizontally flipped.
            x = torch.cat([x, x.flip(-1)], dim=0)
            pred = model(x)
            # Average the logits of each image and its mirrored copy.
            pred = pred.view(2, x_mb.size(0), -1).mean(0)
            uids.append(name)
            # Move to the CPU right away so predictions do not pile up on the GPU.
            all_pred.append(torch.sigmoid(pred).cpu())
    all_pred = torch.cat(all_pred, dim=0).numpy()
    final_pred[model_idx] = all_pred

# Average the probabilities over all checkpoints.
final_pred = final_pred.mean(axis=0)
print(final_pred.shape)

# Predictions follow the dataloader order, so the id column is overwritten
# with the ids collected during inference to keep rows and ids aligned.
target_cols = df.columns[1:]
df[target_cols] = final_pred
df["StudyInstanceUID"] = np.concatenate(uids)
df.to_csv('submission.csv', index=False)
df.head()