In [3]:
import torch
import torchvision
import torch.nn as nn

### Squeeze-and-Excitation [[paper](https://arxiv.org/abs/1709.01507)]

In [4]:
class SqueezeExcitationModule(nn.Module):
    """Squeeze-and-Excitation channel attention.

    The input is average-pooled to one value per channel (squeeze), that
    descriptor is passed through a two-layer 1x1-conv bottleneck ending in a
    sigmoid (excitation), and the resulting per-channel gate rescales the input.

    Args:
        c_in: Number of channels of the input feature map.
        channel_reduction_factor: Divisor applied to ``c_in`` to obtain the
            bottleneck width (``c_in // channel_reduction_factor``).
    """

    def __init__(self, c_in, channel_reduction_factor):
        super(SqueezeExcitationModule, self).__init__()

        self.c_in = c_in
        self.channel_reduction_factor = channel_reduction_factor

        # Global average pooling: one scalar per channel.
        self.squeeze = nn.AdaptiveAvgPool2d(output_size=(1, 1))

        # Bottleneck (reduce -> ReLU -> expand) followed by a sigmoid gate.
        self.excitation = nn.Sequential(
            nn.Conv2d(
                in_channels=self.c_in,
                out_channels=self.c_in // self.channel_reduction_factor,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=self.c_in // self.channel_reduction_factor,
                out_channels=self.c_in,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale every channel of ``x`` by its learned gate.

        Args:
            x: Feature map with ``c_in`` channels.

        Returns:
            A new tensor with the same shape as ``x``.
        """
        o = self.squeeze(x)  # spatial dims collapsed to 1 x 1
        # The excitation must consume the pooled descriptor, not the raw
        # feature map; otherwise the gate is per-pixel and the squeeze is unused.
        o = self.excitation(o)

        # The 1 x 1 gate broadcasts over the spatial dimensions.
        return x * o


class ResSqexModule(nn.Module):
    """Residual block with Squeeze-and-Excitation between two 3x3 convolutions.

    Pipeline: conv/BN/ReLU (``c_in`` -> ``c_in``) -> SE gating -> add the block
    input -> conv/BN/ReLU (``c_in`` -> ``c_out``).

    Args:
        c_in: Channel count of the input (and of the residual branch).
        c_out: Channel count produced by the final convolution.
        channel_reduction_factor: Forwarded to ``SqueezeExcitationModule``.
    """

    def __init__(self, c_in, c_out, channel_reduction_factor):
        super().__init__()

        self.c_in, self.c_out = c_in, c_out
        self.channel_reduction_factor = channel_reduction_factor

        # Entry stage keeps the channel count so the skip connection lines up.
        entry_layers = [
            nn.Conv2d(c_in, c_in, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c_in),
            nn.ReLU(),
        ]
        self.conv_start = nn.Sequential(*entry_layers)

        self.sqex = SqueezeExcitationModule(
            c_in=c_in,
            channel_reduction_factor=channel_reduction_factor,
        )

        # Exit stage maps the fused features to the requested width.
        exit_layers = [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
        ]
        self.conv_end = nn.Sequential(*exit_layers)

    def forward(self, x):
        """Apply the block to ``x`` and return the ``c_out``-channel result."""
        gated = self.sqex(self.conv_start(x))
        fused = gated + x  # residual connection
        return self.conv_end(fused)

### CBAM: Convolutional Block Attention Module [[paper](https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf)]

In [5]:
class CamModule(nn.Module):
    """Channel attention module (the channel half of CBAM).

    Average- and max-pooled channel descriptors go through one shared 1x1-conv
    bottleneck; the two results are summed, squashed with a sigmoid, and used
    to rescale the input channel-wise.

    Args:
        c_in: Number of channels of the input feature map.
        channel_reduction_factor: Divisor applied to ``c_in`` to obtain the
            bottleneck width. Defaults to 2.
    """

    def __init__(self, c_in, channel_reduction_factor=2):
        super(CamModule, self).__init__()

        self.c_in = c_in
        self.channel_reduction_factor = channel_reduction_factor

        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d(output_size=(1, 1))

        # The same bottleneck (same weights) is applied to both descriptors.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=self.c_in, out_channels=self.c_in // self.channel_reduction_factor, kernel_size=1, stride=1, padding=0),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.c_in // self.channel_reduction_factor, out_channels=self.c_in, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x):
        """Return ``x`` scaled by the channel attention gate.

        Args:
            x: Feature map with ``c_in`` channels.

        Returns:
            A new tensor with the same shape as ``x``; ``x`` itself is left
            untouched.
        """
        g = self.avg_pool(x)
        m = self.max_pool(x)

        g_out = self.conv_block(g)
        m_out = self.conv_block(m)

        o = torch.sigmoid(g_out + m_out)

        # Out-of-place multiply: an in-place ``x *= o`` would overwrite the
        # caller's tensor and break autograd when ``x`` is an activation whose
        # value is needed for backward (e.g. a ReLU output) or a leaf tensor
        # that requires grad.
        return x * o
    
class SamModule(nn.Module):
    """Spatial attention module (the spatial half of CBAM).

    The input is reduced along the channel axis to a 2-channel map (max and
    mean), a 3x3 convolution turns that into a single-channel map, and its
    sigmoid rescales the input at every spatial location.
    """

    def __init__(self):
        super(SamModule, self).__init__()

        # 2 input channels = [channel-wise max, channel-wise mean].
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``x`` scaled by the spatial attention gate.

        Args:
            x: Batched feature map; channels are expected on dim 1.

        Returns:
            A new tensor with the same shape as ``x``; ``x`` itself is left
            untouched.
        """
        o = self.channel_pool(x)
        o = self.conv(o)
        o = torch.sigmoid(o)

        # Out-of-place multiply: an in-place ``x *= o`` would overwrite the
        # caller's tensor and break autograd when ``x`` is an activation whose
        # value is needed for backward or a leaf tensor that requires grad.
        return x * o

    def channel_pool(self, x):
        """Stack the max and the mean over dim 1 into a 2-channel tensor."""
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        mean_map = torch.mean(x, dim=1, keepdim=True)
        return torch.cat((max_map, mean_map), dim=1)


class ResCbamModule(nn.Module):
    """Residual block with CBAM attention between two 3x3 convolutions.

    Pipeline: conv/GroupNorm/ReLU (``c_in`` -> ``c_in``) -> channel attention
    -> spatial attention -> add the block input -> conv/GroupNorm/ReLU
    (``c_in`` -> ``c_out``).

    Args:
        c_in: Channel count of the input (and of the residual branch).
        c_out: Channel count produced by the final convolution.
        channel_reduction_factor: Forwarded to ``CamModule``.
    """

    def __init__(self, c_in, c_out, channel_reduction_factor):
        super(ResCbamModule, self).__init__()

        self.c_in = c_in
        self.c_out = c_out
        self.channel_reduction_factor = channel_reduction_factor

        self.conv_start = nn.Sequential(
            nn.Conv2d(in_channels=self.c_in, out_channels=self.c_in, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_channels=self.c_in, num_groups=self._num_groups(self.c_in)),
            nn.ReLU()
        )
        self.cam = CamModule(c_in=self.c_in, channel_reduction_factor=self.channel_reduction_factor)
        self.sam = SamModule()

        self.conv_end = nn.Sequential(
            nn.Conv2d(in_channels=self.c_in, out_channels=self.c_out, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_channels=self.c_out, num_groups=self._num_groups(self.c_out)),
            nn.ReLU()
        )

    @staticmethod
    def _num_groups(num_channels):
        """Pick a GroupNorm group count that divides ``num_channels``.

        Returns ``num_channels // 2`` whenever that is a valid group count
        (every even channel count, and 3), otherwise the largest smaller
        divisor. A plain ``num_channels // 2`` yields 0 groups for a single
        channel and a non-divisor for odd counts >= 5, both of which make
        ``nn.GroupNorm`` raise at construction time.
        """
        groups = max(1, num_channels // 2)
        while num_channels % groups:  # always stops at 1 at the latest
            groups -= 1
        return groups

    def forward(self, x):
        """Apply the block to ``x`` and return the ``c_out``-channel result."""
        o = self.conv_start(x)
        o = self.cam(o)
        o = self.sam(o)
        # Out-of-place residual add: never overwrite the tensor handed back by
        # the attention stage, which may alias an activation needed for
        # backward or the caller's own input.
        o = o + x
        o = self.conv_end(o)

        return o

### Self-Attention [[paper](https://arxiv.org/abs/1711.07971)]

In [6]:
class SelfAttentionModule(nn.Module):
    """Self-attention over the spatial positions of a feature map.

    Three 1x1 convolutions produce the ``f``/``g`` projections (``c_in // k``
    channels each) and the ``h`` projection (``c_in`` channels). The softmax of
    the pairwise ``f``-``g`` products weights ``h``; the result is scaled by
    the learnable scalar ``gamma`` (initialised to zero) and added to the input.

    Args:
        c_in: Number of channels of the input feature map.
        k: Channel reduction factor for the ``f`` and ``g`` projections.
    """

    def __init__(self, c_in, k):
        super(SelfAttentionModule, self).__init__()

        self.c_in = c_in
        self.k = k

        self.cnn_f = nn.Conv2d(in_channels=c_in, out_channels=c_in // k, kernel_size=1, stride=1)
        self.cnn_g = nn.Conv2d(in_channels=c_in, out_channels=c_in // k, kernel_size=1, stride=1)
        self.cnn_h = nn.Conv2d(in_channels=c_in, out_channels=c_in, kernel_size=1, stride=1)

        # Zero-initialised, so the module starts out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``.

        Args:
            x: Batched feature map with ``c_in`` channels (B x C x H x W).
        """
        batch_size = x.shape[0]

        f = self.cnn_f(x)  # B x C/k x H x W
        g = self.cnn_g(x)  # B x C/k x H x W
        h = self.cnn_h(x)  # B x C x H x W

        # N = H * W flattened spatial positions. ``reshape`` instead of
        # ``view`` so non-contiguous conv outputs do not raise.
        f = f.reshape(batch_size, self.c_in // self.k, -1).permute(0, 2, 1)  # B x N x C/k
        g = g.reshape(batch_size, self.c_in // self.k, -1)  # B x C/k x N
        h = h.reshape(batch_size, self.c_in, -1)  # B x C x N

        s = torch.bmm(f, g)  # B x N x N, s[:, i, j] = <f_i, g_j>

        # Normalise over i so each output position j gets weights summing to 1.
        # ``torch.softmax`` is used because ``F`` is not imported alongside
        # ``torch`` / ``torch.nn``.
        b = torch.softmax(s, dim=1)  # B x N x N

        hb = torch.bmm(h, b)  # B x C x N

        hb = hb.reshape(x.shape)  # B x C x H x W

        o = self.gamma * hb + x  # B x C x H x W

        return o
    
    
class ResSattnModule(nn.Module):
    """Residual block with self-attention between two 3x3 convolutions.

    Pipeline: conv/GroupNorm/ReLU (``c_in`` -> ``c_in``) -> self-attention ->
    add the block input -> conv/GroupNorm/ReLU (``c_in`` -> ``c_out``).

    Args:
        c_in: Channel count of the input (and of the residual branch).
        c_out: Channel count produced by the final convolution.
        channel_reduction_factor: Passed to ``SelfAttentionModule`` as ``k``.
    """

    def __init__(self, c_in, c_out, channel_reduction_factor):
        super(ResSattnModule, self).__init__()

        self.c_in = c_in
        self.c_out = c_out
        self.channel_reduction_factor = channel_reduction_factor

        self.conv_start = nn.Sequential(
            nn.Conv2d(in_channels=self.c_in, out_channels=self.c_in, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_channels=self.c_in, num_groups=self._num_groups(self.c_in)),
            nn.ReLU()
        )

        self.sattn = SelfAttentionModule(c_in=self.c_in, k=self.channel_reduction_factor)

        self.conv_end = nn.Sequential(
            nn.Conv2d(in_channels=self.c_in, out_channels=self.c_out, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(num_channels=self.c_out, num_groups=self._num_groups(self.c_out)),
            nn.ReLU()
        )

    @staticmethod
    def _num_groups(num_channels):
        """Pick a GroupNorm group count that divides ``num_channels``.

        Returns ``num_channels // 2`` whenever that is a valid group count
        (every even channel count, and 3), otherwise the largest smaller
        divisor. A plain ``num_channels // 2`` yields 0 groups for a single
        channel and a non-divisor for odd counts >= 5, both of which make
        ``nn.GroupNorm`` raise at construction time.
        """
        groups = max(1, num_channels // 2)
        while num_channels % groups:  # always stops at 1 at the latest
            groups -= 1
        return groups

    def forward(self, x):
        """Apply the block to ``x`` and return the ``c_out``-channel result."""
        o = self.conv_start(x)
        o = self.sattn(o)
        # Out-of-place residual add, so the tensor returned by the attention
        # stage is never overwritten.
        o = o + x
        o = self.conv_end(o)

        return o