AttaNet: Attention-Augmented Network for Fast and Accurate Scene Parsing

Official PyTorch implementation: https://github.com/songqi-github/AttaNet

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:
class SAM(nn.Module):
    """Strip Attention Module (SAM) from AttaNet.

    Keys and values are average-pooled over the full input height, so each of
    the ``s`` resulting "strips" summarizes a band of columns. Every pixel
    computes a softmax distribution over the strips (query . key) and gathers
    the corresponding convex combination of strip values. The result is
    added to the input as a residual.

    Args:
        in_chans: number of input channels (also the output channel count).
        h: unused. It is kept so existing ``SAM(in_chans, h, w, r)`` call
            sites keep working; pooling always collapses the whole input
            height.
        w: number of strips the width is pooled into. ``None`` uses the
            input width (one strip per column, as in the official code).
            For a 64-pixel-wide input the default ``64`` equals ``None``.
        r: channel reduction ratio for the query/key projections.
    """

    def __init__(self, in_chans=512, h=64, w=64, r=8):
        super(SAM, self).__init__()

        # Value projection keeps the channel count so it can be added to x.
        self.conv2d = nn.Conv2d(in_chans, in_chans, kernel_size=1)

        # Query / key projections, reduced by ``r`` to keep attention cheap.
        self.conv2dq = nn.Conv2d(in_chans, in_chans // r, kernel_size=1)
        self.conv2dk = nn.Conv2d(in_chans, in_chans // r, kernel_size=1)

        # Keys and values MUST be pooled along the same axis: the strip index
        # attended over in the attention map is the index used to gather
        # values. Height collapses to 1 and width becomes ``w`` strips.
        self.adapavgpoolk = nn.AdaptiveAvgPool2d((1, w))
        self.adapavgpoolv = nn.AdaptiveAvgPool2d((1, w))

    def forward(self, x):
        """Apply strip attention to ``x`` of shape (b, c, h, w).

        Returns a tensor with the same shape as ``x``.
        """
        b, c, h, w = x.shape

        q = self.conv2dq(x).flatten(2)                     # (b, c/r, h*w)
        k = self.adapavgpoolk(self.conv2dk(x)).squeeze(2)  # (b, c/r, s)
        v = self.adapavgpoolv(self.conv2d(x)).squeeze(2)   # (b, c, s)

        # One distribution over the s strips for every pixel n.
        attn = torch.einsum('bcn,bcs->bns', q, k)
        attn = F.softmax(attn, dim=-1)

        # Each pixel gathers a weighted sum of strip values.
        out = torch.einsum('bns,bcs->bcn', attn, v).reshape(b, -1, h, w)
        return out + x
        
if __name__ == "__main__":
    # Smoke test: SAM keeps the input shape unchanged.
    sample = torch.randn(2, 512, 64, 64)
    print(SAM()(sample).shape)

torch.Size([2, 512, 64, 64])


In [30]:
# Official implementation (from the AttaNet repository)
class StripAttentionModule(nn.Module):
    """Strip attention as written in the official AttaNet repository.

    Queries are 1x1-projected to 64 channels and average-pooled over the full
    height, giving one query vector per column ("strip"). Keys are per-pixel
    64-channel projections. Values are projected to ``out_chan`` channels and
    pooled the same way as the queries. Each pixel receives a softmax-weighted
    (over strips) sum of strip values, which is added to ``x``.

    Depends on ``ConvBNReLU`` and ``BatchNorm2d``, which are not defined in
    this notebook (they come from the official repository).

    Args:
        in_chan: channels of the input ``x``.
        out_chan: channels of the value projection. The residual
            ``x + augmented_feature_map`` presumably requires
            ``out_chan == in_chan``; confirm.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(StripAttentionModule, self).__init__()
        # Query and key projections into a fixed 64-channel space.
        self.conv1 = ConvBNReLU(in_chan, 64, ks=1, stride=1, padding=0)
        self.conv2 = ConvBNReLU(in_chan, 64, ks=1, stride=1, padding=0)
        # Value projection.
        self.conv3 = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        # Applied to a (b, strips, pixels) map: dim=1 normalizes over strips.
        self.softmax = nn.Softmax(dim=1)

        self.init_weight()

    def forward(self, x):
        # Assumes x is (b, in_chan, H, W); ConvBNReLU with ks=1/stride=1/padding=0
        # presumably keeps H and W. TODO confirm.
        q = self.conv1(x)
        batchsize, c_middle, h, w = q.size()
        # Average over the full height: (b, 64, 1, w) -> (b, w, 64), one query per strip.
        q = F.avg_pool2d(q, [h, 1])
        q = q.view(batchsize, c_middle, -1).permute(0, 2, 1)

        # Keys stay per-pixel: (b, 64, h*w).
        k = self.conv2(x)
        k = k.view(batchsize, c_middle, -1)
        # (b, w, h*w); softmax over dim=1, so each pixel's weights over strips sum to 1.
        attention_map = torch.bmm(q, k)
        attention_map = self.softmax(attention_map)

        # Values are pooled like the queries: (b, c_out, w).
        v = self.conv3(x)
        c_out = v.size()[1]
        v = F.avg_pool2d(v, [h, 1])
        v = v.view(batchsize, c_out, -1)

        # (b, c_out, w) @ (b, w, h*w) -> (b, c_out, h*w), reshaped back to (b, c_out, h, w).
        augmented_feature_map = torch.bmm(v, attention_map)
        augmented_feature_map = augmented_feature_map.view(batchsize, c_out, h, w)
        # Residual connection.
        out = x + augmented_feature_map
        return out

    def init_weight(self):
        """Kaiming-normal init (a=1) with zero bias for direct ``nn.Conv2d`` children.

        NOTE(review): only ``self.children()`` is scanned. The direct children
        here are ConvBNReLU wrappers and a Softmax, so unless ConvBNReLU is an
        ``nn.Conv2d`` subclass this loop initializes nothing. Presumably
        ConvBNReLU handles its own init; confirm.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into ``(wd_params, nowd_params)`` for an optimizer.

        Weights of ``nn.Linear``/``nn.Conv2d`` modules go to ``wd_params``.
        Their biases and all parameters of ``BatchNorm2d`` modules go to
        ``nowd_params``. Parameters of any other module type are not
        collected. ``BatchNorm2d`` must be in scope; it is not defined in this
        notebook.
        """
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params

In [29]:
class AFM(nn.Module):
    """Attention Fusion Module: channel-gated fusion of two feature maps.

    ``f_j`` (``in_chans_j`` channels) is projected by a 3x3 conv to
    ``in_chans_i`` channels, and its spatial size defines the output size.
    ``f_i`` (``in_chans_i`` channels) is refined by a 3x3 conv. A per-channel
    gate ``alpha`` in (0, 1), of shape (b, in_chans_i, 1, 1), is predicted
    from the *un-refined* ``f_i`` (bilinearly resized to the output size)
    concatenated with the projected ``f_j``. The output is
    ``resize(alpha * refined_f_i) + (1 - alpha) * projected_f_j``.
    """

    def __init__(self, in_chans_i=512, in_chans_j=1024):
        super(AFM, self).__init__()

        # 3x3 refinement of f_i (channel count preserved, spatial size kept).
        self.conv3x3a = nn.Conv2d(in_chans_i, in_chans_i, kernel_size=3, stride=1, padding=1)
        # 3x3 projection of f_j down to f_i's channel count.
        self.conv3x3b = nn.Conv2d(in_chans_j, in_chans_i, kernel_size=3, stride=1, padding=1)

        # Gate head: 1x1 conv, ReLU, global average pool, 1x1 conv, sigmoid.
        gate_layers = [
            nn.Conv2d(2 * in_chans_i, in_chans_i, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_chans_i, in_chans_i, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.conv = nn.Sequential(*gate_layers)

    def forward(self, f_i, f_j):
        """Fuse ``f_i`` into ``f_j``'s grid; returns (b, in_chans_i, *f_j.shape[2:])."""
        out_size = f_j.shape[2:]

        def to_out_size(t):
            return F.interpolate(t, size=out_size, mode='bilinear', align_corners=True)

        refined_i = self.conv3x3a(f_i)      # still at f_i's resolution
        projected_j = self.conv3x3b(f_j)    # at the output resolution

        # The gate sees the raw (un-refined) f_i, resized to the output grid.
        alpha = self.conv(torch.cat((to_out_size(f_i), projected_j), dim=1))

        weighted_i = to_out_size(alpha * refined_i)
        weighted_j = (1 - alpha) * projected_j
        return weighted_i + weighted_j
        
if __name__ == "__main__":
    # Smoke test: the 128x128, 512-channel f_i is fused onto f_j's 64x64 grid.
    f_i = torch.randn(2, 512, 128, 128)
    f_j = torch.randn(2, 1024, 64, 64)
    print(AFM()(f_i, f_j).shape)
        

torch.Size([2, 512, 64, 64])


In [None]:
# Official implementation (from the AttaNet repository)
class AttentionFusionModule(nn.Module):
    """Attention Fusion Module as written in the official AttaNet repository.

    Despite the name, ``forward`` returns only the channel attention map
    ``atten`` of shape (b, out_chan, 1, 1). Applying it to the two feature
    maps is left to the caller.

    Depends on ``ConvBNReLU`` and ``BatchNorm2d``, which are not defined in
    this notebook (they come from the official repository).

    Args:
        in_chan: must equal channels(feat16) + channels(feat32), because the
            two are concatenated before ``self.conv``.
        out_chan: channels of the attention map.
        *args, **kwargs: accepted and ignored.
    """

    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionFusionModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        # 1x1 conv -> BN -> sigmoid on the globally pooled descriptor.
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()

        self.init_weight()

    def forward(self, feat16, feat32):
        # Resize feat32 to feat16's spatial size (nearest neighbour) and concatenate on channels.
        feat32_up = F.interpolate(feat32, feat16.size()[2:], mode='nearest')
        fcat = torch.cat([feat16, feat32_up], dim=1)
        feat = self.conv(fcat)

        # Global average pool over the whole map -> (b, out_chan, 1, 1).
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        # NOTE(review): if BatchNorm2d is nn.BatchNorm2d, training with batch
        # size 1 fails here (only one value per channel); confirm.
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        return atten

    def init_weight(self):
        """Kaiming-normal init (a=1) with zero bias for direct ``nn.Conv2d`` children.

        Only ``self.children()`` is scanned, so ``conv_atten`` (bias=False) is
        initialized here. The ConvBNReLU child is skipped unless it is an
        ``nn.Conv2d`` subclass; presumably it handles its own init. Confirm.
        """
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        """Split parameters into ``(wd_params, nowd_params)`` for an optimizer.

        Weights of ``nn.Linear``/``nn.Conv2d`` modules go to ``wd_params``.
        Their biases and all parameters of ``BatchNorm2d`` modules go to
        ``nowd_params``. Parameters of any other module type are not
        collected. ``BatchNorm2d`` must be in scope; it is not defined in this
        notebook.
        """
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params