## [ViTDet](https://arxiv.org/pdf/2203.16527.pdf) - The go-to architecture for image foundation models

As of January 2024, ViTDet is a go-to backbone architecture for many vision tasks. It is used in `segment-anything`, and the [`ViTAE-Transformer`](https://github.com/ViTAE-Transformer) projects report state-of-the-art results on several tasks such as semantic segmentation, object detection, human pose estimation, matting, and remote sensing. Understanding this backbone helps us choose good parameters for a given task.

The original ViTDet paper explores using a plain, non-hierarchical ViT as an object-detection backbone, instead of designing a specialized hierarchical backbone for detection. In a way, I would call it a super-simplified `Swin Transformer`: it removes the hierarchical structure of the network, the shifted windows, etc.

Note: we will only talk about the backbone and leave the FPN-based ablation studies to the reader.

So the network is broadly divided into:
> [PatchEmbed] -> nx[blocks] -> [Neck]

Inside each block, we have:
- window attention
- relative positional encoding

We will talk about all of these.

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import fastcore.all as fc
from PIL import Image
from functools import partial
from torchvision.transforms import RandomResizedCrop, RandomHorizontalFlip, Compose, ToTensor, ToPILImage

import matplotlib.pyplot as plt
plt.style.use("bmh")
%matplotlib inline

> Let's use an image size of 1024x1024 with a patch size of 32.

In [None]:
img_size = 1024
patch_size = 32

## Load and visualize an image

We use images from the COCO `val2017` split. For the purposes of this blog, you can pick any image you like from the internet.

In [None]:
imgs = fc.L(fc.Path("coco/val2017/").glob("*.jpg"))
imgs

(#5000) [Path('coco/val2017/000000182611.jpg'),Path('coco/val2017/000000335177.jpg'),Path('coco/val2017/000000278705.jpg'),Path('coco/val2017/000000463618.jpg'),Path('coco/val2017/000000568981.jpg'),Path('coco/val2017/000000092416.jpg'),Path('coco/val2017/000000173830.jpg'),Path('coco/val2017/000000476215.jpg'),Path('coco/val2017/000000479126.jpg'),Path('coco/val2017/000000570664.jpg')...]

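> The following transforms are a simplified version of the paper's augmentation, which uses large-scale jitter (random rescaling followed by a 1024x1024 crop) and horizontal flips.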

In [None]:
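from torchvision.transforms import InterpolationMode

def transforms(size=img_size):
    # Random crop (40-100% of the area, aspect ratio 3:4 to 4:3) resized to size x size,
    # a random horizontal flip, then conversion to a [3, H, W] float tensor in [0, 1]
    return Compose([RandomResizedCrop(size=size, scale=(0.4, 1.0), ratio=(0.75, 1.33),
                                      interpolation=InterpolationMode.BILINEAR),
                    RandomHorizontalFlip(p=0.5), ToTensor()])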

In [None]:
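def load_img(img_loc, transforms):
    # convert("RGB") guards against grayscale images (COCO contains a few single-channel ones)
    img = Image.open(img_loc).convert("RGB")
    return transforms(img)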

load_img = partial(load_img, transforms=transforms())

In [None]:
img = load_img(imgs[1])
img.shape

torch.Size([3, 1024, 1024])

## Patch Embed

We create one embedding per [3x32x32] patch. A single convolution whose kernel size and stride both equal `patch_size` does exactly this.
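Why does a strided convolution work as a patch embedding? With kernel size = stride = patch size, each output position sees exactly one non-overlapping patch, so the conv is the same as flattening every patch and applying one shared linear layer. The sketch below checks this with small, made-up sizes (4x4 patches, 8 output channels instead of 32 and 768) using `torch.nn.functional.unfold`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
patch_size, in_ch, hidden = 4, 3, 8          # small illustrative sizes, not the blog's 32 / 768
x = torch.randn(1, in_ch, 16, 16)

conv = nn.Conv2d(in_ch, hidden, kernel_size=patch_size, stride=patch_size)

# Cut the image into non-overlapping patches and flatten each: [B, C*p*p, num_patches]
patches = nn.functional.unfold(x, kernel_size=patch_size, stride=patch_size)
# Reuse the conv weights as a plain linear layer: [hidden, C*p*p]
w = conv.weight.reshape(hidden, -1)
linear_out = w @ patches + conv.bias[:, None]      # [B, hidden, num_patches]

conv_out = conv(x).flatten(2)                      # [B, hidden, num_patches], row-major over the grid
print(torch.allclose(conv_out, linear_out, atol=1e-5))  # True
```

The conv form is simply a convenient, fast way to write "flatten each patch, then apply a shared linear projection", and it keeps the output on a 2D grid.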

In [None]:
num_channels = 3
hidden_size = 768
projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
projection

Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))

In [None]:
pe = projection(img.unsqueeze(0))
pe.shape

torch.Size([1, 768, 32, 32])

> Move the channel dimension last, giving a [batch, height, width, channels] layout.

In [None]:
pe = pe.permute((0, 2, 3, 1))
pe.shape

torch.Size([1, 32, 32, 768])

> Now we have 32x32 = 1024 tokens, each a 768-dimensional vector. The tokens stay arranged on the 32x32 grid, so each token's position relative to the others is preserved in the tensor layout.

> Optionally, we can add positional embeddings to these features at this point.

## Transformer Blocks 

In each transformer block, we first partition the tokens into windows, then compute attention within each window, merge the windows back, and apply an MLP. The block also has skip connections and normalization layers, as shown below.

<img src="images/vitdet_block.png" width=250 height=200>
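To make the skip connections concrete, here is a minimal sketch of a pre-norm transformer block: LayerNorm before attention and before the MLP, each wrapped in a residual connection, which is how I understand the SAM `Block` (norm1, attn, norm2, mlp) to be arranged. `PreNormBlock` is a hypothetical name, and it uses `nn.MultiheadAttention` over all tokens as a stand-in, so it has no windowing and no relative positions.

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Hypothetical minimal block: x + attn(norm1(x)), then x + mlp(norm2(x)).
    Global attention via nn.MultiheadAttention stands in for windowed attention."""
    def __init__(self, dim=768, num_heads=12, mlp_ratio=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * mlp_ratio), nn.GELU(),
                                 nn.Linear(dim * mlp_ratio, dim))

    def forward(self, x):                      # x: [B, H, W, C]
        B, H, W, C = x.shape
        tokens = x.reshape(B, H * W, C)
        h = self.norm1(tokens)
        tokens = tokens + self.attn(h, h, h, need_weights=False)[0]  # skip connection 1
        tokens = tokens + self.mlp(self.norm2(tokens))               # skip connection 2
        return tokens.reshape(B, H, W, C)

block = PreNormBlock()
out = block(torch.randn(1, 32, 32, 768))
print(out.shape)  # torch.Size([1, 32, 32, 768])
```

ViTDet keeps this overall layout and swaps the attention for windowed attention with relative positions, which the rest of this post builds step by step.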

## Windowing
In `ViTDet`, windowing is optional: attention can instead be computed over all tokens at once, which is called `global attention`. Global attention is expensive, though: with 32x32 = 1024 tokens, every block computes a 1024x1024 attention matrix per head. The number of tokens grows as (image_size / patch_size)^2 and the attention matrix grows with the square of the token count, so halving the patch size makes the matrix 16x larger. Window attention avoids this cost:

- The 32x32 token grid is split into non-overlapping 8x8 windows (`window_size = 8`), giving (32/8) * (32/8) = 16 windows of 8x8 = 64 tokens each. Attention is computed only among the tokens inside each window, which makes it `local attention`.

<img src="images/vitdet_windows.png" width=400 height=400>
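To make the savings concrete, here is a small back-of-the-envelope helper (`attention_entries` is a hypothetical name) that counts query-key scores per head for global versus windowed attention. The second configuration uses 16-pixel patches and 16-token windows, the values passed to the SAM encoder at the end of this post.

```python
def attention_entries(grid, window=None):
    """Number of query-key scores per head for a grid x grid token map."""
    n_tokens = grid * grid
    if window is None:                      # global: every token attends to every token
        return n_tokens ** 2
    n_windows = (grid // window) ** 2       # assumes grid is divisible by window
    return n_windows * (window * window) ** 2

for img, patch, window in [(1024, 32, 8), (1024, 16, 16)]:
    grid = img // patch
    g, w = attention_entries(grid), attention_entries(grid, window)
    print(f"patch={patch}: global={g:,} windowed={w:,} ratio={g // w}")
# patch=32: global=1,048,576 windowed=65,536 ratio=16
# patch=16: global=16,777,216 windowed=1,048,576 ratio=16
```

The ratio is just the number of windows: with N tokens and windows of w x w tokens, windowed attention costs (N / w²) · w⁴ = N · w² scores versus N² for global attention.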

In [None]:
window_size = 8
batch_size, height, width, num_channels = pe.shape
wpe = pe.view(
        batch_size, height // window_size, window_size, width // window_size, window_size, num_channels
    )
wpe.shape

torch.Size([1, 4, 8, 4, 8, 768])

In [None]:
windows = wpe.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, num_channels)
windows.shape

torch.Size([16, 8, 8, 768])

In [None]:
windows = windows.view(-1, window_size*window_size, num_channels)
windows.shape

torch.Size([16, 64, 768])
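It is easy to get the `view`/`permute` order wrong here, so a quick sanity check helps. This self-contained sketch repeats the partition above on random data and confirms that window `i` is the spatial block at row `r`, column `c` with `r, c = divmod(i, 4)`, i.e. windows are ordered row by row.

```python
import torch

torch.manual_seed(0)
B, H, W, C, ws = 1, 32, 32, 768, 8
x = torch.randn(B, H, W, C)

# Same partition as above: [B, H/ws, ws, W/ws, ws, C] -> [B * num_windows, ws, ws, C]
windows = (x.view(B, H // ws, ws, W // ws, ws, C)
             .permute(0, 1, 3, 2, 4, 5)
             .reshape(-1, ws, ws, C))

# Window i should cover rows r*ws:(r+1)*ws and columns c*ws:(c+1)*ws
ok = all(
    torch.equal(windows[i], x[0, r * ws:(r + 1) * ws, c * ws:(c + 1) * ws])
    for i in range(windows.shape[0])
    for r, c in [divmod(i, W // ws)]
)
print(windows.shape, ok)  # torch.Size([16, 8, 8, 768]) True
```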

## Attention
This is the standard scaled dot-product attention from the [`attention is all you need`](https://arxiv.org/pdf/1706.03762.pdf) paper. Let's go through it step by step.

<img src="images/vitdet_attention.png" width=200 height=200>
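For reference, the steps below compute standard scaled dot-product attention independently for every window and head. With `dim = 768` and `num_heads = 4`, as in the cells below, $d_{\text{head}} = 192$ and each window holds 64 tokens:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^\top}{\sqrt{d_{\text{head}}}}\right) V,
\qquad Q, K, V \in \mathbb{R}^{64 \times 192},
$$

so each window and head produces a $64 \times 64$ score matrix, and the scale factor is $1/\sqrt{192} \approx 0.0722$.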

> We obtain the q, k, v matrices with linear projections, here one `nn.Linear` per head for clarity.

In [None]:
dim = windows.shape[-1]
num_heads = 4
head_dim = dim // num_heads
scale = head_dim**-0.5
wq = [nn.Linear(dim, head_dim) for head in range(num_heads)]
wk = [nn.Linear(dim, head_dim) for head in range(num_heads)]
wv = [nn.Linear(dim, head_dim) for head in range(num_heads)]
wq, wk, wv

([Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True)],
 [Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True)],
 [Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True),
  Linear(in_features=768, out_features=192, bias=True)])
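One `nn.Linear` per head is easy to read, but implementations usually fuse all projections into one layer: the SAM encoder printed at the end of this post has `qkv: Linear(in_features=768, out_features=2304)`, i.e. 3 x 768 outputs. The sketch below (my own illustration, not SAM's code) checks that slicing a fused layer's weights recovers one head's q projection, assuming the fused output is laid out as (q/k/v, head, head_dim).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, num_heads = 768, 4
head_dim = dim // num_heads
x = torch.randn(16, 64, dim)                          # [num_windows, tokens_per_window, dim]

# One fused projection producing q, k and v for all heads at once (768 -> 3 * 768)
qkv = nn.Linear(dim, 3 * dim)
out = qkv(x).reshape(16, 64, 3, num_heads, head_dim)  # split last dim into (q/k/v, head, head_dim)
q_fused = out[:, :, 0].permute(0, 2, 1, 3)            # [windows, heads, tokens, head_dim]

# A per-head q projection built from the matching slice of the fused weights
h = 2
w_q_h = qkv.weight[h * head_dim:(h + 1) * head_dim]   # rows producing q for head h
b_q_h = qkv.bias[h * head_dim:(h + 1) * head_dim]
q_h = x @ w_q_h.T + b_q_h                             # [windows, tokens, head_dim]
print(torch.allclose(q_fused[:, h], q_h, atol=1e-5))  # True
```

The fused version does a single large matmul instead of many small ones, which is usually faster on GPUs.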

In [None]:
q = [i(windows) for i in wq]
k = [i(windows) for i in wk]
v = [i(windows) for i in wv]
[i.shape for i in q]

[torch.Size([16, 64, 192]),
 torch.Size([16, 64, 192]),
 torch.Size([16, 64, 192]),
 torch.Size([16, 64, 192])]

In [None]:
# Stack along dim 0: [num_heads * num_windows, tokens, head_dim], ordered head-major
# (all windows of head 0, then all windows of head 1, ...)
q = torch.cat(q)
k = torch.cat(k)
v = torch.cat(v)
q.shape, k.shape, v.shape

(torch.Size([64, 64, 192]),
 torch.Size([64, 64, 192]),
 torch.Size([64, 64, 192]))

> Multiply q by the transpose of k and scale by 1/sqrt(head_dim)

In [None]:
attention_scores = (q @ k.transpose(-2, -1)) * scale
attention_scores.shape

torch.Size([64, 64, 64])

> Apply relative positional encodings

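Relative positional encoding deserves a post of its own. In short, instead of adding positional embeddings once at the input, as in the vanilla ViT, every attention block adds a position-dependent term to the attention scores. ViTDet uses the decomposed form: one learned table indexed by the height offset and one by the width offset between query and key, each combined with the query vector.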

In [None]:
rel_pos_h = nn.Parameter(torch.zeros(2 * window_size - 1, head_dim))
rel_pos_w = nn.Parameter(torch.zeros(2 * window_size - 1, head_dim))
rel_pos_h.shape, rel_pos_w.shape

(torch.Size([15, 192]), torch.Size([15, 192]))

In [None]:
from transformers.models.vitdet.modeling_vitdet import add_decomposed_relative_positions

In [None]:
attention_scores = add_decomposed_relative_positions(
                attention_scores, q, rel_pos_h, rel_pos_w, (window_size, window_size), (window_size, window_size)
            )
attention_scores.shape

torch.Size([64, 64, 64])
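The `transformers` helper hides how the two tables are used. Below is a minimal sketch of decomposed relative positions for the square, equal-size windows used here (no table interpolation). `decomposed_rel_pos` is a hypothetical name; I wrote it to follow my reading of the SAM/HF helpers, so treat it as an illustration rather than a drop-in replacement.

```python
import torch

def decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, size):
    """Hypothetical sketch for a square size x size window, no table interpolation.
    attn: [B, S*S, S*S], q: [B, S*S, C], rel_pos_h / rel_pos_w: [2*S - 1, C]."""
    B, _, C = q.shape
    coords = torch.arange(size)
    # Offset (query - key) lies in [-(S-1), S-1]; shifting by S-1 turns it into a table row
    idx = coords[:, None] - coords[None, :] + (size - 1)      # [S, S]
    Rh, Rw = rel_pos_h[idx], rel_pos_w[idx]                    # [S, S, C] each
    r_q = q.reshape(B, size, size, C)                          # [B, q_row, q_col, C]
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)            # [B, q_row, q_col, k_row]
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)            # [B, q_row, q_col, k_col]
    attn = attn.reshape(B, size, size, size, size)             # [B, q_row, q_col, k_row, k_col]
    attn = attn + rel_h[..., :, None] + rel_w[..., None, :]    # broadcast over k_col / k_row
    return attn.reshape(B, size * size, size * size)

# Same shapes as the cells above: 64 (head, window) pairs, 8x8 windows, head_dim 192
q = torch.randn(64, 64, 192)
scores = torch.randn(64, 64, 64)
zero_h, zero_w = torch.zeros(15, 192), torch.zeros(15, 192)
out = decomposed_rel_pos(scores, q, zero_h, zero_w, 8)
print(out.shape, torch.equal(out, scores))  # torch.Size([64, 64, 64]) True
```

With zero-initialised tables, like `rel_pos_h` and `rel_pos_w` in the cell above, the scores come out unchanged; the tables only start to matter once they are trained.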

> Apply softmax

In [None]:
attention_probs = attention_scores.softmax(dim=-1)
attention_probs.shape

torch.Size([64, 64, 64])

> Multiply by the value vectors (v)

In [None]:
hidden_state = attention_probs @ v
hidden_state.shape

torch.Size([64, 64, 192])

In [None]:
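num_windows = windows.shape[0]  # 16
# q/k/v were concatenated head-major, so split the leading dim as (num_heads, num_windows)
hidden_state = hidden_state.view(num_heads, num_windows, window_size, window_size, head_dim)
# -> [num_windows, window_size, window_size, num_heads, head_dim], then merge heads into channels
hidden_state = hidden_state.permute(1, 2, 3, 0, 4)
hidden_state = hidden_state.reshape(num_windows, window_size, window_size, -1)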
hidden_state.shape

torch.Size([16, 8, 8, 768])
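`torch.cat` stacks the per-head outputs head-major (all 16 windows of head 0, then all 16 of head 1, and so on), so the leading dimension has to be split as `(num_heads, num_windows)`. Splitting it the other way round still produces the right shape but silently mixes heads and windows. A tiny tagged example makes the ordering visible:

```python
import torch

num_heads, num_windows = 4, 16
# Tag every entry with its origin: value = 100 * head + window
per_head = [torch.arange(num_windows).float() + 100 * h for h in range(num_heads)]
stacked = torch.cat(per_head)                   # head-major, like torch.cat(q) above

print(stacked.view(num_heads, num_windows)[1, 3].item())  # 103.0 -> head 1, window 3 (correct)
print(stacked.view(num_windows, num_heads)[3, 1].item())  # 13.0  -> really head 0, window 13 (wrong split)
```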

> Add projection layer 

In [None]:
proj = nn.Linear(dim, dim)
proj

Linear(in_features=768, out_features=768, bias=True)

In [None]:
attention_out = proj(hidden_state)
attention_out.shape

torch.Size([16, 8, 8, 768])

## Unwindowing 
Merge the windows back into the (batch_size, height, width, embedding_dim) layout.

In [None]:
# [num_windows, ws, ws, C] -> [B, H/ws, W/ws, ws, ws, C] -> [B, H, W, C]
pe = attention_out.view(-1, height // window_size, width // window_size,
                        window_size, window_size, num_channels)
pe = pe.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, height, width, num_channels)
pe.shape

torch.Size([1, 32, 32, 768])
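The cells above rely on 32 being divisible by 8. When the token grid is not divisible by the window size, one approach (as far as I know, the one taken by the helpers of the same name in SAM's `image_encoder.py`) is to zero-pad the grid before partitioning and crop the padding after merging. Below is a self-contained sketch modeled on that idea, not a copy of SAM's code.

```python
import torch
import torch.nn.functional as F

def window_partition(x, ws):
    """[B, H, W, C] -> ([B * num_windows, ws, ws, C], (Hp, Wp)); zero-pads H and W to multiples of ws."""
    B, H, W, C = x.shape
    pad_h, pad_w = (ws - H % ws) % ws, (ws - W % ws) % ws
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))   # pads (C: none, W: right, H: bottom)
    Hp, Wp = H + pad_h, W + pad_w
    x = x.view(B, Hp // ws, ws, Wp // ws, ws, C).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, ws, ws, C), (Hp, Wp)

def window_unpartition(windows, ws, padded_hw, hw):
    """Inverse of window_partition: [B * num_windows, ws, ws, C] -> [B, H, W, C], dropping padding."""
    (Hp, Wp), (H, W) = padded_hw, hw
    B = windows.shape[0] // ((Hp // ws) * (Wp // ws))
    x = windows.view(B, Hp // ws, Wp // ws, ws, ws, -1).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(B, Hp, Wp, -1)[:, :H, :W, :]

x = torch.randn(2, 14, 14, 768)                # a 14x14 grid is not divisible by 8
win, padded = window_partition(x, 8)
print(win.shape, padded)                       # torch.Size([8, 8, 8, 768]) (16, 16)
print(torch.equal(window_unpartition(win, 8, padded, (14, 14)), x))  # True
```

The padded tokens take part in attention inside their window, so the padding trades a little wasted compute for supporting arbitrary grid sizes.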

> We have now partitioned the tokens into windows, applied attention, and merged the windows back.

> Note that the output has the same shape as the input.

> If you want global attention instead, set `window_size` to the full token-grid size, 32 here, so that the whole 32x32 grid is a single window.

## Residual block 
So far, attention is computed only within windows. To propagate information across windows, ViTDet applies global attention in a few blocks; because global attention is expensive, it is used sparingly:
- The blocks are split evenly into 4 subsets, e.g. 6 blocks each for the 24-block ViT-L, or 3 blocks each for the 12-block ViT-B-sized encoder used later in this post.
- The last block of each subset uses global attention.

This keeps the computation low while still letting tokens exchange information across windows.
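A small hypothetical helper (`global_attn_indexes` is my name for it) computes which blocks get global attention under this scheme. For the 12-block encoder it gives `[2, 5, 8, 11]`, the same `global_attn_indexes` passed to SAM's `ImageEncoderViT` in the full-scale example.

```python
def global_attn_indexes(depth, num_subsets=4):
    """0-based index of the last block in each of `num_subsets` equal-sized subsets."""
    per_subset = depth // num_subsets
    return [(i + 1) * per_subset - 1 for i in range(num_subsets)]

print(global_attn_indexes(12))  # [2, 5, 8, 11]
print(global_attn_indexes(24))  # [5, 11, 17, 23]
```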

As an alternative to global attention, the paper also proposes convolutional propagation: a residual bottleneck block of 1x1, 3x3, and 1x1 convolutions inserted after each subset. The 3x3 convolution mixes each token with its neighbours, including neighbours across window borders, so information can flow between windows.

In [None]:
from transformers.models.vitdet.modeling_vitdet import VitDetResBottleneckBlock

In [None]:
class config:
    hidden_act = "gelu"
residual = VitDetResBottleneckBlock(config, in_channels=768, out_channels=768, bottleneck_channels=768//2)
residual

VitDetResBottleneckBlock(
  (conv1): Conv2d(768, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (norm1): VitDetLayerNorm()
  (act1): GELUActivation()
  (conv2): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (norm2): VitDetLayerNorm()
  (act2): GELUActivation()
  (conv3): Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (norm3): VitDetLayerNorm()
)

In [None]:
residual(pe.permute((0, 3, 1, 2))).shape

torch.Size([1, 768, 32, 32])
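For readers without `transformers`, here is a pure-PyTorch sketch of the same 1x1, 3x3, 1x1 bottleneck wrapped in a residual connection. `ResBottleneck` and `LayerNorm2d` are hypothetical names, and zero-initialising the last norm so the block starts as an identity mapping is an assumption on my part (a common trick for inserted residual blocks), not something shown by the cells above.

```python
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """LayerNorm over the channel dimension of a [B, C, H, W] tensor."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class ResBottleneck(nn.Module):
    """Hypothetical sketch: x + [1x1 conv, norm, act, 3x3 conv, norm, act, 1x1 conv, norm](x)."""
    def __init__(self, channels=768, bottleneck=384):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, bottleneck, 1, bias=False), LayerNorm2d(bottleneck), nn.GELU(),
            nn.Conv2d(bottleneck, bottleneck, 3, padding=1, bias=False), LayerNorm2d(bottleneck), nn.GELU(),
            nn.Conv2d(bottleneck, channels, 1, bias=False), LayerNorm2d(channels),
        )
        # Assumption: zero-init the last norm so the residual branch starts at 0 (identity block)
        nn.init.zeros_(self.body[-1].norm.weight)
        nn.init.zeros_(self.body[-1].norm.bias)

    def forward(self, x):                      # x: [B, C, H, W]
        return x + self.body(x)

x = torch.randn(1, 768, 32, 32)
block = ResBottleneck()
y = block(x)
print(y.shape, torch.equal(y, x))  # torch.Size([1, 768, 32, 32]) True
```

Starting as an identity means inserting the block into a pre-trained backbone does not change its initial outputs.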

That is the core of the network. Now let's instantiate a full image encoder from `Segment Anything` with ViT-B-sized settings and check that everything makes sense.

## Full scale network

In [None]:
from segment_anything.modeling.image_encoder import ImageEncoderViT

In [None]:
# ViT-B-sized encoder: blocks 2, 5, 8, 11 use global attention, the rest use 16x16 windows
enc = ImageEncoderViT(img_size=1024,
                      patch_size=16,
                      in_chans=3,
                      embed_dim=768,
                      depth=12,
                      num_heads=12,
                      mlp_ratio=4,
                      out_chans=256,
                      qkv_bias=True,
                      norm_layer=nn.LayerNorm,
                      act_layer=nn.GELU,
                      use_abs_pos=False,
                      use_rel_pos=True,
                      window_size=16,
                      global_attn_indexes=[2, 5, 8, 11])
enc

ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=768, out_features=3072, bias=True)
        (lin2): Linear(in_features=3072, out_features=768, bias=True)
        (act): GELU(approximate='none')
      )
    )
  )
  (neck): Sequential(
    (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): LayerNorm2d()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (3): LayerNorm2d()
  )
)

In [None]:
enc(img.unsqueeze(0)).shape

torch.Size([1, 256, 64, 64])

## Ablation studies
- Window attention is sufficient when aided by a few global-attention blocks.
- Residual conv blocks and global attention give similar performance, while training and inference are faster with the residual conv blocks.
- Masked Autoencoder (MAE) pre-training provides strong backbones.
- Compared with hierarchical backbones such as MViTv2 and Swin Transformer, ViTDet works better.
- Finally, ViTDet reaches 61.3 box AP on the COCO test set using only ImageNet-1K pre-training with MAE.

In the next post in this series, we will look at what MAE is and how to apply it to plain ViTs and ViTDet. Thank you.