# Vision Transformers

- https://github.com/google-research/vision_transformer
- https://arxiv.org/abs/2010.11929
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py


![img](https://tse2.mm.bing.net/th?id=OIP.cyCA4XEM1F4ueNS2ADCa9wAAAA&pid=Api) | ![im](https://raw.githubusercontent.com/google-research/vision_transformer/master/figure1.png)
---|---


### Attention, Self Attention and Multi-head Self Attention mechanisms

#### Attention mechanism

Refs:
- https://arxiv.org/pdf/2012.12877.pdf
- https://d2l.ai/chapter_attention-mechanisms/index.html
- https://d2l.ai/chapter_attention-mechanisms/multihead-attention.html
- https://d2l.ai/chapter_attention-mechanisms/self-attention-and-positional-encoding.html

The attention mechanism is based on a trainable associative memory with (key, value) vector pairs.

- A sequence of $N$ query vectors, $Q \in \mathbb{R}^{N \times d}$, is matched against
- a set of $k$ key vectors, $K \in \mathbb{R}^{k \times d}$, using inner products.
- These inner products are then scaled and normalized with a softmax function to obtain $k$ weights per query.
- The output of the attention is the weighted sum of a set of $k$ value vectors (packed into $V \in \mathbb{R}^{k \times v}$).

$$
Attention(Q,K,V) = Softmax(Q K^T / \sqrt{d}) V \in \mathbb{R}^{N \times v}
$$

Dropout is applied to the attention weights during training.


A more general form is nonparametric attention pooling, where $\alpha(q, k_i)$ is the weight assigned to the value $v_i$:
$$
f(q) = \sum_{i=1}^{k}\alpha(q, k_i) v_i
$$
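
As a minimal sketch of the pooling formula above, the snippet below computes $f(q)$ for a single query, taking $\alpha$ as a softmax over a pluggable score function. The helper names (`attention_pooling`, `scaled_dot`, `gaussian`) are illustrative assumptions, not part of any library. With the scaled dot-product score it should reproduce one row of the matrix formula $Softmax(Q K^T / \sqrt{d}) V$.

```python
import math
import torch


def attention_pooling(q, keys, values, score_fn):
    """f(q) = sum_i alpha(q, k_i) v_i, with alpha = softmax_i(score_fn(q, k_i))."""
    scores = torch.stack([score_fn(q, k_i) for k_i in keys])  # shape (k,)
    alpha = torch.softmax(scores, dim=0)                      # k weights summing to 1
    return alpha @ values                                     # shape (v,)


def scaled_dot(q, k_i):
    # Score used by scaled dot-product attention
    return (q @ k_i) / math.sqrt(q.shape[-1])


def gaussian(q, k_i):
    # Alternative score: closer keys receive larger weights
    return -0.5 * ((q - k_i) ** 2).sum()


torch.manual_seed(0)
q = torch.rand(3)          # one query, d = 3
keys = torch.rand(4, 3)    # k = 4 keys
values = torch.rand(4, 2)  # k = 4 values, v = 2

out_dot = attention_pooling(q, keys, values, scaled_dot)
out_gauss = attention_pooling(q, keys, values, gaussian)

# Same result as the matrix form of scaled dot-product attention
ref = torch.softmax(q @ keys.T / math.sqrt(3), dim=0) @ values
print(torch.allclose(out_dot, ref, atol=1e-6))
```

Swapping the score function changes only how the weights are computed; the weighted sum over values stays the same.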


In [1]:
import math
import torch

torch.manual_seed(1)

k = 4
d = 3
N = 5


Q = (torch.rand(N, d) > 0.7).float()
K = torch.eye(k, d)
V = torch.arange(k * d, dtype=torch.float32).reshape(k, d) * 4.5
Q

tensor([[1., 0., 0.],
        [1., 0., 1.],
        [0., 1., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [2]:
K

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.],
        [0., 0., 0.]])

In [3]:
V

tensor([[ 0.0000,  4.5000,  9.0000],
        [13.5000, 18.0000, 22.5000],
        [27.0000, 31.5000, 36.0000],
        [40.5000, 45.0000, 49.5000]])

In [4]:
W = torch.softmax(Q @ K.transpose(0, 1) / math.sqrt(d), dim=1)
W

tensor([[0.3726, 0.2091, 0.2091, 0.2091],
        [0.3202, 0.1798, 0.3202, 0.1798],
        [0.2091, 0.3726, 0.2091, 0.2091],
        [0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500]])

In [5]:
W @ V

tensor([[16.9410, 21.4410, 25.9410],
        [18.3538, 22.8538, 27.3538],
        [19.1470, 23.6470, 28.1470],
        [20.2500, 24.7500, 29.2500],
        [20.2500, 24.7500, 29.2500]])

#### Self Attention

The query, key and value matrices are themselves computed from a sequence of $N$ input vectors, $X \in \mathbb{R}^{N \times D}$:

- $Q = X W_{Q} \in \mathbb{R}^{N \times d}$
- $K = X W_{K} \in \mathbb{R}^{N \times d}$
- $V = X W_{V} \in \mathbb{R}^{N \times d}$


$$
Head = SelfAttention(X, \{W_{Q}, W_{K}, W_{V}\}) = Softmax(X W_{Q} (X W_{K})^T / \sqrt{d}) (X W_{V}) 
$$


In [6]:
N = 5
D = 4
d = 3

x = torch.rand(N, D)
W_q = torch.rand(D, d)
W_k = torch.rand(D, d)
W_v = torch.rand(D, d)

Q = x @ W_q
K = x @ W_k
V = x @ W_v
Q.shape, K.shape, V.shape

(torch.Size([5, 3]), torch.Size([5, 3]), torch.Size([5, 3]))

In [7]:
from torch.nn.functional import dropout

W = dropout(torch.softmax(Q @ K.transpose(0, 1) / math.sqrt(d), dim=1))
(W @ V).shape

torch.Size([5, 3])
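
A short sketch of what `dropout` does to the attention weights, under the default settings of `torch.nn.functional.dropout` (`p=0.5`, `training=True`): entries are zeroed at random and the survivors are rescaled by `1 / (1 - p)`, so rows generally stop summing to 1. With `training=False` the weights pass through unchanged. The tensor `W_demo` is a stand-in introduced for this illustration.

```python
import torch
from torch.nn.functional import dropout

torch.manual_seed(0)
W_demo = torch.softmax(torch.rand(5, 5), dim=1)  # stand-in attention weights

W_train = dropout(W_demo, p=0.5, training=True)   # random zeroing + rescaling
W_eval = dropout(W_demo, p=0.5, training=False)   # identity at inference time

print(W_demo.sum(dim=1))   # each row sums to 1
print(W_train.sum(dim=1))  # rows generally no longer sum to 1

# Surviving entries are scaled by 1 / (1 - p) = 2
kept = W_train != 0
print(torch.allclose(W_train[kept], 2 * W_demo[kept]))
```

Inside an `nn.Module`, `nn.Dropout` handles the train/eval switch automatically through `model.train()` and `model.eval()`.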

#### Multi-head Self Attention

- A head is a single self-attention layer
- The $h$ heads are concatenated
- An output linear transform $W_O$ is applied


$$
MultiHeadAttention(X, \{W^{0,1...h-1}_{Q}, W^{0,1...h-1}_{K}, W^{0,1...h-1}_{V}\}) = Concat(Head_0, Head_1, ..., Head_{h-1}) W_{O}
$$



In [8]:
N = 5
D = 6
d = 4
o = 3

num_heads = 2

x = torch.rand(N, D)
W_qkv = torch.rand(D, 3 * num_heads * d)
W_o = torch.rand(num_heads * d, o)

Q, K, V = (x @ W_qkv).chunk(3, dim=1)
Q.shape, K.shape, V.shape

(torch.Size([5, 8]), torch.Size([5, 8]), torch.Size([5, 8]))

In [9]:
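# Split the packed projections into heads: (N, num_heads * d) -> (num_heads, N, d)
Qh, Kh, Vh = (t.reshape(N, num_heads, d).transpose(0, 1) for t in (Q, K, V))
W = dropout(torch.softmax(Qh @ Kh.transpose(1, 2) / math.sqrt(d), dim=-1))
# Concatenate the heads back: (num_heads, N, d) -> (N, num_heads * d)
heads = (W @ Vh).transpose(0, 1).reshape(N, num_heads * d)
heads.shape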

torch.Size([5, 8])

In [10]:
(heads @ W_o).shape

torch.Size([5, 3])
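
To connect the code to the $Concat(Head_0, ..., Head_{h-1}) W_O$ formula, here is a hedged sketch that computes the heads in two ways: with an explicit loop over heads, and with a single batched reshape. It assumes head $i$ uses columns `i * d : (i + 1) * d` of the packed projections, and dropout is left out so the two results can be compared. The helper `split_heads` is introduced here for illustration only.

```python
import math
import torch

torch.manual_seed(0)
N, D, d, o, num_heads = 5, 6, 4, 3, 2

x = torch.rand(N, D)
W_qkv = torch.rand(D, 3 * num_heads * d)
W_o = torch.rand(num_heads * d, o)
Q, K, V = (x @ W_qkv).chunk(3, dim=1)  # each (N, num_heads * d)

# 1) Explicit loop: one self-attention per head, then concatenation
head_list = []
for i in range(num_heads):
    cols = slice(i * d, (i + 1) * d)
    A_i = torch.softmax(Q[:, cols] @ K[:, cols].T / math.sqrt(d), dim=1)  # (N, N)
    head_list.append(A_i @ V[:, cols])                                    # (N, d)
out_loop = torch.cat(head_list, dim=1) @ W_o                              # (N, o)


# 2) Batched: move the head axis to the front and let matmul broadcast
def split_heads(t):
    return t.reshape(N, num_heads, d).transpose(0, 1)  # (num_heads, N, d)


A = torch.softmax(split_heads(Q) @ split_heads(K).transpose(1, 2) / math.sqrt(d), dim=-1)
heads = (A @ split_heads(V)).transpose(0, 1).reshape(N, num_heads * d)
out_batched = heads @ W_o

print(out_loop.shape, torch.allclose(out_loop, out_batched, atol=1e-4))
```

The batched form is what practical implementations tend to use, since it avoids a Python loop over heads.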

#### Multi-head Self Attention in Vision


- Input images to patches: `(B, C, H, W) -> (B, (H // p) * (W // p), D)`

In [11]:
import torch.nn as nn


patch_size = 4
embed_dim = 3 * patch_size * patch_size
conv = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)

batch_images = torch.rand(4, 3, 32, 32)
patches = conv(batch_images)
patches = patches.flatten(start_dim=2)
patches = patches.transpose(1, 2)
print(batch_images.shape, "->", patches.shape)

torch.Size([4, 3, 32, 32]) -> torch.Size([4, 64, 48])
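
The strided convolution above can be read as "cut the image into non-overlapping patches, flatten each patch, apply one shared linear layer". The sketch below checks this equivalence with `torch.nn.functional.unfold`, reusing the convolution's own weights. Variable names such as `patches_linear` are assumptions made for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
patch_size = 4
embed_dim = 3 * patch_size * patch_size
conv = nn.Conv2d(3, embed_dim, kernel_size=patch_size, stride=patch_size)
batch_images = torch.rand(4, 3, 32, 32)

# Path 1: strided convolution, as in the cell above
patches_conv = conv(batch_images).flatten(start_dim=2).transpose(1, 2)  # (4, 64, 48)

# Path 2: extract flattened patches, then apply the same weights as a linear map
flat = F.unfold(batch_images, kernel_size=patch_size, stride=patch_size)  # (4, C*p*p, 64)
flat = flat.transpose(1, 2)                                               # (4, 64, C*p*p)
weight = conv.weight.reshape(embed_dim, -1)                               # (48, C*p*p)
patches_linear = flat @ weight.T + conv.bias                              # (4, 64, 48)

print(torch.allclose(patches_conv, patches_linear, atol=1e-5))
```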


- multi-head attention ops:

```text
x, (B, N, embed_dim) ->
            -> Q = x @ W_q, (B, N, embed_dim) ->
            -> K = x @ W_k, (B, N, embed_dim) ->
            -> V = x @ W_v, (B, N, embed_dim) ->
            -> A = softmax(Q @ K^t * scale), (B, N, N) -> Dropout(A) ->
            -> H = A @ V, (B, N, embed_dim) ->
            -> y = H @ W_o, (B, N, embed_dim) -> Dropout(y) ->
            -> output

```

In [12]:
class VisionAttention(nn.Module):
    """Vision Multi-Head Attention layer with trainable parameters:
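    - W_q, W_k, W_v : embed_dim * embed_dim * 3
    - W_o : embed_dim * embed_dim (+ embed_dim bias)

    Note: in this simplified version Q, K and V are not split into
    separate heads; `num_heads` only sets the softmax scale.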
    
    .. code-block:: text

        x, (B, N, embed_dim) -> 
                    -> Q = X * W_q, (B, N, embed_dim) -> 
                    -> K = X * W_k, (B, N, embed_dim) ->
                    -> V = X * W_v, (B, N, embed_dim) ->
                    -> A = softmax(Q @ K^t * scale), (B, N, N) -> Dropout(A) ->
                    -> H = A @ V, (B, N, embed_dim) -> 
                    -> y = H @ W_o, (B, N, embed_dim) ->
                    -> output
                    
    https://github.com/google/flax/blob/master/flax/nn/attention.py
    """
    
    def __init__(self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.):
        super().__init__()
        head_dim = embed_dim // num_heads
        self.scale = float(head_dim) ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.att_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        Q, K, V = self.qkv(x).chunk(3, dim=-1)
        attention = (Q @ K.transpose(1, 2)) * self.scale
        attention = attention.softmax(dim=-1)
        attention = self.att_drop(attention)
        heads = attention @ V
        output = self.proj(heads)
        return output

In [13]:
num_heads = 6
att = VisionAttention(embed_dim, num_heads=num_heads)

In [14]:
# Compute Q, K, V
Q, K, V = att.qkv(patches).chunk(3, dim=-1)
Q.shape, K.shape, V.shape

(torch.Size([4, 64, 48]), torch.Size([4, 64, 48]), torch.Size([4, 64, 48]))

In [15]:
# Compute heads
head_dim = embed_dim // num_heads
attention = (Q @ K.transpose(1, 2)) * float(head_dim) ** -0.5
attention = attention.softmax(dim=-1)
attention = att.att_drop(attention)
attention.shape

torch.Size([4, 64, 64])

In [16]:
heads = attention @ V
heads.shape

torch.Size([4, 64, 48])

In [18]:
output = att.proj(heads)
output.shape, output.shape == att(patches).shape

(torch.Size([4, 64, 48]), True)
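
The `VisionAttention` layer above computes a single `(B, N, N)` attention map over the full `embed_dim`. A hedged sketch of a variant that really splits the projections into `num_heads` independent heads is given below. The class name `MultiHeadVisionAttention` is hypothetical, and the layout (head axis moved next to the batch axis) is one common choice rather than the only one. It requires `embed_dim` to be divisible by `num_heads`, so a configuration such as 512 features with 6 heads would need adjusting.

```python
import torch
import torch.nn as nn


class MultiHeadVisionAttention(nn.Module):
    """Hypothetical multi-head variant: one (N, N) attention map per head."""

    def __init__(self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError("embed_dim must be divisible by num_heads")
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = float(self.head_dim) ** -0.5
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
        self.att_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3 * C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        Q, K, V = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        attention = (Q @ K.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attention = self.att_drop(attention.softmax(dim=-1))
        # Concatenate heads: (B, num_heads, N, head_dim) -> (B, N, C)
        heads = (attention @ V).transpose(1, 2).reshape(B, N, C)
        return self.proj(heads)


mh_att = MultiHeadVisionAttention(embed_dim=48, num_heads=6)
mh_out = mh_att(torch.rand(4, 64, 48))
print(mh_out.shape)
```

The parameter count is identical to the single-map version; only the way Q, K and V are grouped before the softmax changes.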

### Transformers

Refs:
- https://d2l.ai/chapter_attention-mechanisms/transformer.html


In NLP, the Transformer is composed of an encoder and a decoder.
Positional encodings are added to the input (source) and output (target) sequence embeddings
before they are fed into the encoder and the decoder, which stack modules based on self-attention.

<div style="background: white;">
<img src="https://d2l.ai/_images/transformer.svg" />
</div>

#### [Vision Transformer for classification](https://arxiv.org/abs/2010.11929)

![im](https://raw.githubusercontent.com/google-research/vision_transformer/master/figure1.png)

Inspired by the Transformer scaling successes in NLP, the authors experiment with applying a standard Transformer directly to images, with the fewest possible modifications. To do so, they split an image into patches and provide the sequence of linear embeddings of these patches as input to a Transformer. Image patches are treated the same way as tokens (words) in an NLP application. They train the model on image classification in a supervised fashion.

3 types of models:
- ViT-Base, 86M params (vs ResNet50, 23M params)
- ViT-Large, 307M params
- ViT-Huge, 632M params (vs ResNet152x4 from [Big Transfer](https://arxiv.org/abs/1912.11370), 936M)

ImageNet (1.3M images): when trained on this dataset alone, the models reach modest accuracies, a few percentage points below ResNets of comparable size. Transformers **lack some of the inductive biases** inherent to CNNs, such as translation equivariance and locality, and therefore do not generalize well when trained on insufficient amounts of data.

Larger datasets (14M-300M images): large scale training trumps inductive bias. Vision Transformer (ViT) attains "excellent" results when pre-trained at sufficient scale and transferred to tasks with fewer datapoints. 

When pre-trained on the public ImageNet-21k dataset or Google's JFT-300M dataset, ViT approaches or beats state of the art on multiple image recognition benchmarks. 
In particular, the best model reaches the accuracy of 88.55% on ImageNet.

SOTA on ImageNet: https://paperswithcode.com/sota/image-classification-on-imagenet


Self-supervision (e.g. BERT in NLP): with self-supervised pre-training, the smaller ViT-B/16 model achieves 79.9% accuracy on ImageNet, a significant improvement of 2% over training from scratch, but still 4% behind supervised pre-training.


Let's implement the main blocks:
- EncoderBlock: layer norms, vision attention and MLP
- Vision Transformer: patch and position embeddings + EncoderBlocks + classification head

In [19]:
class MLP(nn.Module):
    """The positionwise feed-forward network or MLP.    
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, drop_rate=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class EncoderBlock(nn.Module):
    """ViT EncoderBlock
    
    https://github.com/google-research/vision_transformer/blob/9dbeb0269e0ed1b94701c30933222b49189aa33c/vit_jax/models.py#L94
    """
    
    def __init__(self, embed_dim, num_heads, mlp_ratio=4, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.lnorm1 = nn.LayerNorm(embed_dim)
        self.lnorm2 = nn.LayerNorm(embed_dim)
        self.attention = VisionAttention(
            embed_dim, num_heads=num_heads, attn_drop=attn_drop_rate
        )
        self.dropout = nn.Dropout(drop_rate)
        self.mlp = MLP(embed_dim, int(embed_dim * mlp_ratio), drop_rate=drop_rate)
        
    def forward(self, x):
        y = self.lnorm1(x)
        y = self.attention(y)
        y = self.dropout(y)
        x = x + y
        
        y = self.lnorm2(x)
        y = self.mlp(y)
        return x + y

In [20]:
block = EncoderBlock(embed_dim, num_heads)
out = block(patches)
out.shape

torch.Size([4, 64, 48])

##### Positional encoding

Self-attention processes all tokens at once and is, by itself, unaware of their spatial arrangement.
The idea is to add a learnable parameter to the tokens to retain positional information.

According to the paper, the authors use a 1D positional embedding, treating the input as a sequence of patches in raster order.
They did not observe significant performance gains from more advanced 2D-aware position embeddings.

$$
x = x + pe
$$
where $pe$ is randomly initialized from a Normal distribution.
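
A small sketch of why the positional term matters: plain self-attention is permutation-equivariant, so shuffling the tokens just shuffles the outputs. Once a position-dependent $pe$ is added, the same shuffle changes the result. The weights, the fixed cyclic permutation and the random `pe` below are stand-ins chosen for this illustration, not values from the paper.

```python
import torch

torch.manual_seed(0)
N, D = 6, 8
tokens = torch.rand(N, D)
W_q, W_k, W_v = torch.rand(D, D), torch.rand(D, D), torch.rand(D, D)


def self_attention(x):
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    return torch.softmax(Q @ K.T / D ** 0.5, dim=1) @ V


perm = torch.tensor([1, 2, 3, 4, 5, 0])  # cyclic shift of the token order

# Without positions: permuting the inputs only permutes the outputs
same_without_pe = torch.allclose(
    self_attention(tokens[perm]), self_attention(tokens)[perm], atol=1e-4
)

# With positions: pe stays attached to the slot, not to the token
pe = torch.randn(N, D)
same_with_pe = torch.allclose(
    self_attention(tokens[perm] + pe), self_attention(tokens + pe)[perm], atol=1e-4
)

print(same_without_pe, same_with_pe)
```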

In [21]:
class VisionTransformer(nn.Module):
    """VisionTransformer model
    
    https://github.com/google-research/vision_transformer/
    """
    
    def __init__(
        self, 
        patch_size=16,
        hidden_size=768,
        input_channels=3,
        input_size=224,
        num_classes=1000,
        num_layers=12,
        num_heads=12,
        mlp_dim=3072,
        drop_rate=0.1, 
        attn_drop_rate=0.0,
    ):
        super().__init__()

        self.patchs_embed = nn.Conv2d(
            input_channels, hidden_size, kernel_size=patch_size, stride=patch_size
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
        
        num_patches = (input_size // patch_size) ** 2
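        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, hidden_size))
        # Random Normal init, as described above (std=0.02 is an assumed value)
        nn.init.normal_(self.pos_embed, std=0.02)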
        self.pos_dropout = nn.Dropout(p=drop_rate)  
    
        # Define encoder blocks
        kwargs = {
            "embed_dim": hidden_size,
            "num_heads": num_heads,
            "mlp_ratio": mlp_dim / hidden_size,
            "drop_rate": drop_rate,
            "attn_drop_rate": attn_drop_rate,
        }
        blocks = [EncoderBlock(**kwargs) for _ in range(num_layers)]
        self.blocks = nn.Sequential(*blocks)        
        self.lnorm = nn.LayerNorm(hidden_size)

        self.mlp_head = nn.Linear(hidden_size, num_classes)
    
    def features(self, x):
        patches = self.patchs_embed(x)
        patches = patches.flatten(start_dim=2)
        patches = patches.transpose(1, 2)
        
        batch_size = patches.shape[0]

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, patches], dim=1)
        x = x + self.pos_embed
        x = self.pos_dropout(x)

        x = self.blocks(x)
        x = self.lnorm(x)
        
        # Return the first token
        return x[:, 0, ...]        
    
    def forward(self, x):
        f = self.features(x)
        y = self.mlp_head(f)        
        return y
    

def vit_b16(num_classes=1000, input_channels=3, input_size=224):
    return VisionTransformer(
        num_classes=num_classes,
        input_channels=input_channels,
        input_size=input_size,
        patch_size=16,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        mlp_dim=3072,
        drop_rate=0.1, 
        attn_drop_rate=0.0,        
    )


In [22]:
model = vit_b16()

x = torch.rand(4, 3, 224, 224)
model(x).shape

torch.Size([4, 1000])

In [23]:
sum([m.numel() for m in model.parameters()]) * 1e-6

86.540008

In [76]:
def vit_tiny(num_classes=10, input_channels=3, input_size=32):
    return VisionTransformer(
        num_classes=num_classes,
        input_channels=input_channels,
        input_size=input_size,
        patch_size=4,
        hidden_size=512,
        num_layers=4,
        num_heads=6,
        mlp_dim=1024,
        drop_rate=0.1, 
        attn_drop_rate=0.0,
    )

In [78]:
model = vit_tiny()
sum([m.numel() for m in model.parameters()]) * 1e-6

8.470025999999999
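
The two parameter counts above can be cross-checked by hand. The function below is a sketch that mirrors the layers defined in this notebook (patch embedding convolution, class token, position embedding, encoder blocks with a bias-free `qkv` layer, final layer norm, linear head); `vit_param_count` is a helper name introduced here, not part of any library.

```python
def vit_param_count(patch_size, hidden_size, input_channels, input_size,
                    num_classes, num_layers, mlp_dim):
    # Conv2d patch embedding: weights + bias
    patch_embed = input_channels * patch_size ** 2 * hidden_size + hidden_size
    cls_token = hidden_size
    num_patches = (input_size // patch_size) ** 2
    pos_embed = (num_patches + 1) * hidden_size

    layer_norm = 2 * hidden_size                     # scale + shift
    qkv = hidden_size * 3 * hidden_size              # qkv_bias=False
    proj = hidden_size * hidden_size + hidden_size
    mlp = (hidden_size * mlp_dim + mlp_dim) + (mlp_dim * hidden_size + hidden_size)
    block = 2 * layer_norm + qkv + proj + mlp

    head = hidden_size * num_classes + num_classes
    return patch_embed + cls_token + pos_embed + num_layers * block + layer_norm + head


b16_params = vit_param_count(16, 768, 3, 224, 1000, 12, 3072)
tiny_params = vit_param_count(4, 512, 3, 32, 10, 4, 1024)
print(b16_params)   # 86540008
print(tiny_params)  # 8470026
```

Most of the parameters sit in the encoder blocks; for the B/16 configuration each block contributes 7,085,568 parameters, of which about two thirds belong to the MLP.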

### Train a tiny version on CIFAR10

- On TensorBoard.dev: 

In [87]:
!head -6 cifar10/README.md

# Train ViT on CIFAR10 with [PyTorch-Ignite]()

- on 1 or more GPUs
- compute training/validation metrics
- log learning rate, metrics etc
- save the best model weights


### Visualize attention maps

### [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)

- https://github.com/facebookresearch/deit