# Basics | Attention in Computer Vision

By [Akshaj Verma](https://akshajverma.com)

This notebook walks you through different types of attention methods for computer vision, implemented in PyTorch.

In [1]:
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, utils

## Self Attention for Images

Let's define a tensor that stands in for the feature map we would obtain after passing an image through multiple conv layers.

Let the size of this tensor be `(4, 5, 5)`. This means that our image (the latent representation after multiple conv operations) is of size `(5 x 5)` and has `4` channels.

![Self attention in SAGAN paper](../../assets/sagan_att.png)

[Reference](https://arxiv.org/pdf/1905.08008v1.pdf)

In [2]:
# 100 consecutive values 0.0 ... 99.0 (default float dtype) -- exactly enough to fill a 4 x 5 x 5 tensor.
img = torch.arange(100.0)

In [3]:
# Reinterpret the flat tensor as 4 channels of 5 x 5 pixels (C x H x W); no data is copied.
img = img.view(4, 5, 5)
img

tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.],
         [15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.]],

        [[25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44.],
         [45., 46., 47., 48., 49.]],

        [[50., 51., 52., 53., 54.],
         [55., 56., 57., 58., 59.],
         [60., 61., 62., 63., 64.],
         [65., 66., 67., 68., 69.],
         [70., 71., 72., 73., 74.]],

        [[75., 76., 77., 78., 79.],
         [80., 81., 82., 83., 84.],
         [85., 86., 87., 88., 89.],
         [90., 91., 92., 93., 94.],
         [95., 96., 97., 98., 99.]]])

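We call `unsqueeze(0)` to insert a dimension of size 1 at index 0. This dimension is the batch dimension, turning the `(C, H, W)` tensor into a `(B, C, H, W)` one.
We add it because the attention module works on batched input: `nn.Conv2d()` conventionally takes `(B, C, H, W)` tensors, and `torch.bmm()` requires a batch dimension.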

In [4]:
# Prepend a size-1 batch axis: (C, H, W) -> (B, C, H, W).
input_img = img[None, ...]
input_img.shape

torch.Size([1, 4, 5, 5])

In [15]:
def self_attention_module(input_img, in_channels, k, batch_size=None, input_size=None):
    """Apply a single SAGAN-style self-attention head to a batch of feature maps.

    New 1x1 convolutions f (query), g (key) and h (value) are created with
    PyTorch's default random initialisation on every call. Each call therefore
    acts as an independent, untrained attention head.

    With s[i, j] = f(x_i) . g(x_j), the attention weights are
    beta[i, j] = softmax over i of s[i, j]. Output position j is
    sum_i beta[i, j] * h(x_i), i.e. a weighted average of the value vectors.

    Args:
        input_img: batched feature maps of shape (B, C, H, W), C == in_channels.
        in_channels: number of channels C of ``input_img``.
        k: channel-reduction factor; f and g project to C // k channels.
        batch_size: B. Optional -- inferred from ``input_img`` when None.
        input_size: unused; kept so existing keyword calls keep working.
            H and W are taken from ``input_img``, so non-square inputs work.

    Returns:
        Tensor of shape (B, C, N) with N = H * W. This is the attention output
        before any output projection or residual connection.

    Raises:
        ValueError: if ``in_channels // k`` is 0 (f and g would have no channels).
    """
    query_channels = in_channels // k
    if query_channels < 1:
        raise ValueError(
            f"in_channels // k must be >= 1, got in_channels={in_channels}, k={k}"
        )
    if batch_size is None:
        batch_size = input_img.shape[0]

    # Build the layers in the order f, g, h. This order fixes which random draws
    # each layer receives under a given seed.
    cnn_f = nn.Conv2d(in_channels=in_channels, out_channels=query_channels, kernel_size=1, stride=1)
    cnn_g = nn.Conv2d(in_channels=in_channels, out_channels=query_channels, kernel_size=1, stride=1)
    cnn_h = nn.Conv2d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1)

    f = cnn_f(input_img).view(batch_size, query_channels, -1)  # B x C/k x N
    g = cnn_g(input_img).view(batch_size, query_channels, -1)  # B x C/k x N
    h = cnn_h(input_img).view(batch_size, in_channels, -1)     # B x C x N

    # s[b, i, j] = f(x_i) . g(x_j)
    s = torch.bmm(f.permute(0, 2, 1), g)  # B x N x N

    # Normalise over the source position i (dim=1) so that every column
    # b[:, :, j] sums to 1. bmm(h, b) below sums over i. Normalising over the
    # last dim instead would make rows, not columns, sum to 1, and the output
    # would not be a weighted average.
    b = F.softmax(s, dim=1)  # B x N x N

    # hb[:, :, j] = sum_i h[:, :, i] * b[:, i, j]
    return torch.bmm(h, b)  # B x C x N

In [16]:
# Each call builds fresh, randomly initialised conv layers, so the three calls act as
# three independent attention heads. Seed first so re-running this cell gives the same heads.
torch.manual_seed(0)
sa1, sa2, sa3 = (
    self_attention_module(
        input_img=input_img,
        in_channels=input_img.shape[1],
        k=2,
        batch_size=input_img.shape[0],
        input_size=input_img.shape[-1],
    )
    for _ in range(3)
)

In [17]:
# All heads must agree in shape (B x C x N) before they can be concatenated.
assert sa1.shape == sa2.shape == sa3.shape
sa1.shape

torch.Size([1, 4, 25])


In [28]:
# Join the three heads along the last (position) axis: 3 x (B, C, N) -> (B, C, 3N).
# NOTE(review): standard multi-head attention concatenates heads along the
# channel axis (dim=1) instead -- confirm which layout is intended here.
multi_head_concat = torch.cat((sa1, sa2, sa3), dim=2)
multi_head_concat.shape

torch.Size([1, 4, 75])

In [29]:
# Output projection W0: maps the concatenated length back to a single head's length
# (75 -> 25 here). The sizes are derived from the tensors instead of hard-coded.
# NOTE(review): the zero init makes agg_op identically zero; fine for a shape demo,
# but use a random init (e.g. torch.nn.init.xavier_uniform_) for anything real.
w0 = nn.Parameter(torch.zeros(1, multi_head_concat.shape[-1], sa1.shape[-1]))
w0.shape

torch.Size([1, 75, 25])

In [30]:
# Batched matrix product (B, C, 3N) x (1, 3N, N) -> (B, C, N).
agg_op = torch.bmm(input=multi_head_concat, mat2=w0)
agg_op.shape

torch.Size([1, 4, 25])