In [1]:
import torch
from einops.layers.torch import Rearrange
from einops import repeat, rearrange

## Resources 
- [Vision Transformers](https://arxiv.org/pdf/2010.11929.pdf)
- [Attention by lucidrains](https://github.com/lucidrains/vit-pytorch/blob/4b8f5bc90002a5506d765c811b554760d8dd6ee7/vit_pytorch/vit.py#L35)
- [Transformers by lucidrains](https://github.com/lucidrains/vit-pytorch/blob/4b8f5bc90002a5506d765c811b554760d8dd6ee7/vit_pytorch/vit.py#L67)
- [Attention is all you need](https://arxiv.org/pdf/1706.03762.pdf)


Let's understand the `Vision Transformer` in 6 simple steps.

## step-1
The following are the inputs required by the vision transformer. 
- input image size. 
- patch_size 

From these we can calculate the number of patches and the patch dimension as follows.

> If the image has size $(H, W)$ and each patch has size $(P, P)$, we reshape the image into a tensor in $\mathbb{R}^{N \times (P^2 \cdot C)}$, where $C$ is the number of channels and $N = HW/P^2$ is the number of patches.

In [2]:
image = torch.randn((224*224*3)).reshape((224, 224, 3))
image_shape = image.shape
patch_size = (14, 14)

image_height, image_width, channels = image_shape
patch_height, patch_width = patch_size

In [3]:
num_patches = (image_height // patch_height) * (image_width // patch_width)
patch_dim = channels * patch_height * patch_width

print(f"total_patches: {num_patches}")
print(f"patch_dim: {patch_dim}")

total_patches: 256
patch_dim: 588


> Reshape the input image to [batch_size, total_patches, patch_dim]

In [4]:
mr = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width)
out = mr(image.permute((2, 0, 1)).unsqueeze(0))
out.shape

torch.Size([1, 256, 588])
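
To make the `Rearrange` pattern less magical, here is a minimal sketch of the same patch extraction written with plain `reshape` and `permute`. The helper name `to_patches` is made up for illustration; it assumes a `(b, c, H, W)` input whose height and width divide evenly by the patch size.

```python
import torch

def to_patches(images: torch.Tensor, patch_height: int, patch_width: int) -> torch.Tensor:
    """Split (b, c, H, W) images into flattened patches of shape (b, N, p1*p2*c)."""
    b, c, height, width = images.shape
    assert height % patch_height == 0 and width % patch_width == 0, \
        "image must divide evenly into patches"
    h, w = height // patch_height, width // patch_width
    # (b, c, H, W) -> (b, c, h, p1, w, p2)
    x = images.reshape(b, c, h, patch_height, w, patch_width)
    # -> (b, h, w, p1, p2, c), the same axis order as the einops pattern
    x = x.permute(0, 2, 4, 3, 5, 1)
    # merge (h w) into the patch index and (p1 p2 c) into the feature axis
    return x.reshape(b, h * w, patch_height * patch_width * c)

images = torch.randn(1, 3, 224, 224)
patches = to_patches(images, 14, 14)
print(patches.shape)  # torch.Size([1, 256, 588])
```

Patch `0` is the top-left 14x14 crop and patch `1` is the crop to its right, so patches are numbered row by row.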

## step-2
> The Transformer uses constant latent vector size D through all of its layers, so we flatten the patches and map to D dimensions with a trainable linear projection. We refer to the output of this projection as the `patch embeddings`.

In [5]:
embed_dim = 128
embed = torch.nn.Linear(patch_dim, embed_dim)  # 588 -> 128
with torch.no_grad():
    embed_out = embed(out)
print(embed_out.shape)

torch.Size([1, 256, 128])
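
As a side note, "flatten each patch, then apply a linear layer" can also be written as one strided convolution whose kernel size and stride both equal the patch size. The sketch below is not part of the steps above; it only checks, under the assumption that the linear weights are copied into the convolution in the matching `(p1 p2 c)` order, that the two give the same patch embeddings.

```python
import torch

patch, channels, embed_dim = 14, 3, 128
linear = torch.nn.Linear(patch * patch * channels, embed_dim)
conv = torch.nn.Conv2d(channels, embed_dim, kernel_size=patch, stride=patch)

with torch.no_grad():
    # Linear features are ordered (p1, p2, c); Conv2d kernels are laid out (c, p1, p2).
    conv.weight.copy_(
        linear.weight.reshape(embed_dim, patch, patch, channels).permute(0, 3, 1, 2)
    )
    conv.bias.copy_(linear.bias)

    images = torch.randn(1, channels, 224, 224)
    h = w = 224 // patch
    # Patch extraction with plain reshape/permute, feature order (p1 p2 c).
    patches = (images.reshape(1, channels, h, patch, w, patch)
                     .permute(0, 2, 4, 3, 5, 1)
                     .reshape(1, h * w, patch * patch * channels))
    via_linear = linear(patches)                        # (1, 256, 128)
    via_conv = conv(images).flatten(2).transpose(1, 2)  # (1, 256, 128)

print(via_linear.shape, torch.allclose(via_linear, via_conv, atol=1e-4))
```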


## step-3
> Similar to BERT’s [class] token, we prepend a learnable embedding to the sequence of embedded patches ($z_0^0 = x_\text{class}$), whose state at the output of the Transformer encoder ($z_L^0$) serves as the image representation $y$ (Eq. 4).

In [6]:
cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
cls_token.shape

torch.Size([1, 1, 128])

We have to repeat this `cls_token` `batch_size` times and concatenate it with the `step-2` output.
> We can use either `torch.Tensor.repeat` or `einops.repeat`.

In [7]:
%%time
b, n, _ = embed_out.shape
cls_tokens = cls_token.repeat(b, 1, 1)
cls_tokens.shape

CPU times: user 316 µs, sys: 107 µs, total: 423 µs
Wall time: 284 µs


torch.Size([1, 1, 128])

In [8]:
%%time
b, n, _ = embed_out.shape
cls_tokens = repeat(cls_token, '1 1 d -> b 1 d', b = b)  # one copy of the token per batch element
cls_tokens.shape

CPU times: user 779 µs, sys: 577 µs, total: 1.36 ms
Wall time: 901 µs


torch.Size([1, 1, 128])

In [9]:
x = torch.cat((cls_tokens, embed_out), dim=1)
x.shape

torch.Size([1, 257, 128])
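
A third option, sketched here as an aside, is `torch.Tensor.expand`, which returns a broadcast view instead of copying the token. The batch size of 4 and the random `embed_out` are stand-ins chosen only so the effect is visible.

```python
import torch

embed_dim, batch_size = 128, 4
cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
embed_out = torch.randn(batch_size, 256, embed_dim)  # stand-in for the patch embeddings

# expand returns a view without copying; -1 keeps that dimension unchanged.
cls_tokens = cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, embed_out), dim=1)
print(x.shape)  # torch.Size([4, 257, 128])
```

The copy still happens inside `torch.cat`, so the end result matches the `repeat` versions.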

## step-4
> Position embeddings are added to the patch embeddings to retain positional information. We use standard learnable 1D position embeddings. The resulting sequence of embedding vectors serves as input to the encoder.

In [10]:
pos_embedding = torch.nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
pos_embedding.shape

torch.Size([1, 257, 128])

In [11]:
x += pos_embedding[:, :(n + 1)]
x.shape

torch.Size([1, 257, 128])

> Add dropout if required. 
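
Steps 1 to 4, including the optional dropout, can be sketched as a single module. The class name `PatchEmbedding` and the `emb_dropout` argument are hypothetical names chosen for this illustration, and the patch extraction uses plain `reshape`/`permute` so the block runs with `torch` alone.

```python
import torch

class PatchEmbedding(torch.nn.Module):
    """Hypothetical helper: patches -> linear projection -> [class] token -> positions -> dropout."""
    def __init__(self, image_size=224, patch_size=14, channels=3, embed_dim=128, emb_dropout=0.0):
        super().__init__()
        assert image_size % patch_size == 0, "image must divide evenly into patches"
        self.patch_size = patch_size
        num_patches = (image_size // patch_size) ** 2
        self.embed = torch.nn.Linear(channels * patch_size * patch_size, embed_dim)
        self.cls_token = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = torch.nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.dropout = torch.nn.Dropout(emb_dropout)

    def forward(self, images):
        b, c, height, width = images.shape
        p = self.patch_size
        h, w = height // p, width // p
        # step-1: (b, c, H, W) -> (b, N, p*p*c)
        patches = (images.reshape(b, c, h, p, w, p)
                         .permute(0, 2, 4, 3, 5, 1)
                         .reshape(b, h * w, p * p * c))
        x = self.embed(patches)                                       # step-2
        x = torch.cat((self.cls_token.expand(b, -1, -1), x), dim=1)   # step-3
        x = x + self.pos_embedding[:, : x.shape[1]]                   # step-4
        return self.dropout(x)

pe = PatchEmbedding(emb_dropout=0.1)
tokens = pe(torch.randn(2, 3, 224, 224))
print(tokens.shape)  # torch.Size([2, 257, 128])
```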

## step-5
How does a transformer work?
- a transformer is a stack of blocks, each built around multi-head self-attention
- each transformer block is shown below
- a block has two norm layers, one attention layer and one feed-forward layer (`ffn`); a residual connection is applied after the attention layer and after the `ffn`

To be more specific:
> 5.1) `y1 = Norm(x)`

> 5.2) `z1 = Attention(y1)`

> 5.3) `z1 += x` (residual connection)

> 5.4) `y = Norm(z1)`

> 5.5) `z = FF(y)`

> 5.6) `x = z + z1` (residual connection)

> 5.7) Combine everything into one transformer block; its output `x` is passed to the next transformer block.


<img src="../images/transformer_block.png" alt="alt text" width="125" align="left"/>

We will apply each step in turn.

### 5.1 Norm1

In [12]:
norm1 = torch.nn.LayerNorm(embed_dim)

with torch.no_grad():
    y1 = norm1(x)
y1.shape

torch.Size([1, 257, 128])

### 5.2 Attention
This is the most important part of the Transformer. It involves two parts:
- self-attention
- multi-head

The attention block has the following attributes:
- dim: dimension of the input embeddings
- heads: total number of heads
- dim_head: dimension of each head
- dropout: optional, applied to the attention weights after the softmax


#### self-attention
- We take the normalized tensor obtained in 5.1 and linearly project it into three tensors `q`, `k`, `v` using a linear layer.
- Self-attention is then the dot product of `q` and `k`, divided by $\sqrt{d_{k}}$, followed by a softmax. The result is multiplied with `v`. [Vaswani et al.](https://arxiv.org/pdf/1706.03762.pdf) (section 3.2.1) discuss two questions:
- why is dot-product attention chosen over additive attention?
- why do we scale the output of `qk`?

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_{k}}}\right)V
$$

<img src="../images/attn_qkv.png" alt="alt text" width="200" align="left"/>
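
To get a feel for the scaling question, here is a small numerical sketch. It assumes `q` and `k` have independent, unit-variance entries (an idealisation, not something guaranteed for real activations). Under that assumption the dot product has variance close to $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1, which keeps the softmax from saturating.

```python
import torch

torch.manual_seed(0)
d_k = 64
q = torch.randn(10_000, d_k)  # unit-variance queries
k = torch.randn(10_000, d_k)  # unit-variance keys

raw = (q * k).sum(dim=-1)     # one dot product per row
scaled = raw / d_k ** 0.5

print(f"variance of q.k:           {raw.var().item():.1f}")    # close to d_k
print(f"variance of q.k/sqrt(d_k): {scaled.var().item():.2f}")  # close to 1

# The softmax over unscaled scores is at least as peaked as over scaled ones.
scores = q[:1] @ k[:16].T                        # (1, 16) attention scores
p_raw = torch.softmax(scores, dim=-1)
p_scaled = torch.softmax(scores / d_k ** 0.5, dim=-1)
print(p_raw.max().item(), p_scaled.max().item())
```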


In [13]:
class AttentionHead(torch.nn.Module):
    def __init__(self, dim, dim_head, dropout=0.0):
        super().__init__()
        self.dim = dim 
        self.dim_head = dim_head
        self.scale = dim_head**(-0.5)
        self.attend = torch.nn.Softmax(dim = -1)
        self.to_qkv = torch.nn.Linear(self.dim, self.dim_head*3, bias = False)
        self.drop = torch.nn.Dropout(dropout)
    
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = qkv
        qk = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        qk_soft = self.attend(qk)
        qk_soft = self.drop(qk_soft)
        out = torch.matmul(qk_soft, v)
        return out

In [14]:
ah = AttentionHead(128, 64)
ah(y1).shape

torch.Size([1, 257, 64])

<img src="../images/attn_multihead.png" alt="alt text" width="200" align="left"/>

In [15]:
class AttentionAllHead(torch.nn.Module):
    def __init__(self, dim, heads, dim_head, dropout=0.0):
        super().__init__()
        self.heads = heads
        project_out = not (heads == 1 and dim_head == dim)
        for i in range(self.heads):
            setattr(self, f"head_{i}", AttentionHead(dim, dim_head, dropout))
        #self.attnhead = [AttentionHead(dim, dim_head, dropout) for i in range(self.heads)]
        inner_dim = self.heads * dim_head
        self.to_out = torch.nn.Sequential(
            torch.nn.Linear(inner_dim, dim),
            torch.nn.Dropout(dropout)
        ) if project_out else torch.nn.Identity()
    
    def forward(self, x):
        dd = []
        for i in range(self.heads):
            out = getattr(self, f"head_{i}")(x)
            dd.append(out.unsqueeze(1))
        #dd = [t(x).unsqueeze(1) for t in self.attnhead]
        dd = torch.cat(dd, dim=1)
        out = rearrange(dd, 'b h n d -> b n (h d)')
        return self.to_out(out)

> I have kept the implementation a little simpler. This [repo](https://github.com/lucidrains/vit-pytorch/blob/4b8f5bc90002a5506d765c811b554760d8dd6ee7/vit_pytorch/vit.py#L47) combines both of our classes into a single one called `Attention`.

In [16]:
dim = 128
heads = 8
dim_head = 64
alh = AttentionAllHead(dim, heads, dim_head, dropout=0.0)
z1 = alh(y1)
z1.shape

torch.Size([1, 257, 128])
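
For comparison, a fused version can compute all heads with one `q`/`k`/`v` projection and a couple of reshapes, instead of looping over separate head modules. This is only a sketch of that idea, not a copy of any library's code; the class name `FusedAttention` is made up, and it always applies the output projection.

```python
import torch

class FusedAttention(torch.nn.Module):
    """Hypothetical fused variant: one projection produces q, k, v for every head."""
    def __init__(self, dim, heads, dim_head, dropout=0.0):
        super().__init__()
        self.heads, self.dim_head = heads, dim_head
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head
        self.to_qkv = torch.nn.Linear(dim, inner_dim * 3, bias=False)
        self.drop = torch.nn.Dropout(dropout)
        self.to_out = torch.nn.Linear(inner_dim, dim)

    def forward(self, x):
        b, n, _ = x.shape
        # (b, n, 3*h*d) -> (b, n, 3, h, d) -> three tensors of shape (b, h, n, d)
        qkv = self.to_qkv(x).reshape(b, n, 3, self.heads, self.dim_head)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # attention weights per head: (b, h, n, n)
        attn = torch.softmax(q @ k.transpose(-1, -2) * self.scale, dim=-1)
        out = self.drop(attn) @ v                                  # (b, h, n, d)
        # concatenate the heads along the feature axis: (b, n, h*d)
        out = out.transpose(1, 2).reshape(b, n, self.heads * self.dim_head)
        return self.to_out(out)

fa = FusedAttention(dim=128, heads=8, dim_head=64)
out = fa(torch.randn(1, 257, 128))
print(out.shape)  # torch.Size([1, 257, 128])
```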

### 5.3 residual 

In [17]:
z1+=x 
z1.shape

torch.Size([1, 257, 128])

### 5.4 Norm2 

In [18]:
norm2 = torch.nn.LayerNorm(embed_dim)

with torch.no_grad():
    y = norm2(z1)
y.shape

torch.Size([1, 257, 128])

### 5.5 Feedforward neural network

In [19]:
class FeedForward(torch.nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, dim),
            torch.nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

In [20]:
ff = FeedForward(128, hidden_dim=2048, dropout=0.1)
with torch.no_grad():
    z = ff(y)
z.shape

torch.Size([1, 257, 128])

### 5.6 Another residual layer

In [21]:
x = z+z1
x.shape

torch.Size([1, 257, 128])

> Combine everything

## step-6 Transformer block 

In [22]:
class TransformerBlock(torch.nn.Module):
    def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.dim = dim
        self.norm1 = torch.nn.LayerNorm(dim)
        self.norm2 = torch.nn.LayerNorm(dim)
        self.attn = AttentionAllHead(dim, heads, dim_head, dropout)
        self.ff = FeedForward(dim, hidden_dim=mlp_dim, dropout=dropout)
    
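    def forward(self, x):
        y1 = self.norm1(x)   # 5.1 norm
        z1 = self.attn(y1)   # 5.2 multi-head self-attention
        z1 = z1 + x          # 5.3 residual (out-of-place, so autograd history stays intact)
        y = self.norm2(z1)   # 5.4 norm
        z = self.ff(y)       # 5.5 feed-forward
        x = z + z1           # 5.6 residual
        return x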

In [23]:
tb = TransformerBlock(dim=128, heads=8, dim_head=64, mlp_dim=2048)
tb(embed_out).shape

torch.Size([1, 256, 128])

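> The `vit_pytorch` example below stacks 6 transformer blocks (`depth = 6`). The ViT variants in the paper are deeper: 12 (Base), 24 (Large) and 32 (Huge) layers.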
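
Stacking the blocks and reading the prediction off the [class] token can be sketched as follows. To keep the block self-contained it uses PyTorch's built-in `torch.nn.TransformerEncoderLayer` (pre-norm, GELU) as a stand-in for the `TransformerBlock` above; note that the built-in layer splits `d_model` across heads rather than taking a separate `dim_head`. The class name `TinyEncoder` and `num_classes=10` are assumptions made for illustration.

```python
import torch

class TinyEncoder(torch.nn.Module):
    """Hypothetical stack of pre-norm blocks with a classification head on the [class] token."""
    def __init__(self, dim=128, depth=6, heads=8, mlp_dim=2048, num_classes=10):
        super().__init__()
        self.blocks = torch.nn.ModuleList([
            torch.nn.TransformerEncoderLayer(
                d_model=dim, nhead=heads, dim_feedforward=mlp_dim, dropout=0.0,
                activation="gelu", batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.norm = torch.nn.LayerNorm(dim)
        self.head = torch.nn.Linear(dim, num_classes)

    def forward(self, x):
        # x: (batch, 1 + num_patches, dim), with the [class] token at index 0
        for block in self.blocks:
            x = block(x)
        # classify from the final state of the [class] token
        return self.head(self.norm(x[:, 0]))

enc = TinyEncoder()
logits = enc(torch.randn(2, 257, 128))
print(logits.shape)  # torch.Size([2, 10])
```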

We can use `vit_pytorch` to build a complete Vision Transformer.

In [25]:
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
preds.shape

torch.Size([1, 1000])

In [26]:
sum([params.numel() for name, params in v.named_parameters()])

54622184

> This ViT configuration has about 54.6 million parameters :(