# Deep Learning course - LAB 9

## An explainability-first implementation of the Vision Transformer

This lab will mainly follow the slides from the lecture on the Vision Transformer (ViT).

Please refer to the slides for the methodological explanations.

We will be constructing the ViT bottom-up, i.e. from the input embedding to the output.

_**Disclaimer**: this lab is for explanatory purposes only. If you intend to use a ViT in your work, I suggest one of the many plug-and-play implementations available online (e.g. the one in `timm`), which are probably more efficient than this one. Here, I sacrifice performance for clarity._

In [None]:
import torch
from torch import nn
from torchsummary import summary
from torchvision import transforms as T
from scripts import mnistm, train

## 1a. Patch + vectorize input

The input is first subdivided into patches and each patch is *unrolled* into a 1D vector.

Let us implement a generic torchvision-style transform which we can pass to a `Dataset`'s `transform` attribute.

In [None]:
class ToVecPatch():
    def __init__(self, patch_size, axis_channel=2):
        '''
        axis_channel is the axis (dim) in which the channels are located.
        An image should be h x w x c -> 2
        An image converted with ToTensor() should be c x h x w -> 0
        '''
        assert axis_channel in (0, 2), "Only axis_channel 0 or 2 is supported."
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.axis_channel = axis_channel

    
    def __call__(self, sample):
        '''
        sample is a torch.Tensor
        '''
        if self.axis_channel == 0:
            sample = sample.permute(1,2,0)
        
        patch_vert = torch.cat(sample.split(self.patch_size[1], dim=1))
        patch_horiz = torch.stack(patch_vert.split(self.patch_size[0], dim=0))
        vec_patches = torch.flatten(patch_horiz, start_dim=1)
        return vec_patches

Let's see it in action on a small 4x4 grayscale image:

In [None]:
x = torch.Tensor([[1,2,3,4],[4,5,6,7],[7,8,9,0],[10,7,3,88]])
patch_size = 2
x

In [None]:
patch_vert = torch.cat(x.split(patch_size, dim=1))
patch_vert

In [None]:
patch_horiz = torch.stack(patch_vert.split(patch_size, dim=0))
patch_horiz

In [None]:
vec_patches = torch.flatten(patch_horiz, start_dim=1)
vec_patches

In [None]:
P = ToVecPatch(2)
P(sample=x)
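
The same steps can be checked on a 3-channel input. The following is a minimal sketch, not part of the lab code: `to_vec_patches` is a hypothetical helper that repeats the split/cat/stack/flatten sequence above on an h x w x c tensor, and the 28x28x3 image is synthetic. It also makes the patch ordering explicit: patches are produced one column strip at a time, so patch 1 is the block *below* patch 0, not the one to its right.

```python
import torch

def to_vec_patches(sample, patch_size):
    """Hypothetical helper: same steps as above, for an h x w x c tensor."""
    strips = torch.cat(sample.split(patch_size, dim=1))      # (h * w/p) x p x c
    patches = torch.stack(strips.split(patch_size, dim=0))   # N x p x p x c
    return torch.flatten(patches, start_dim=1)               # N x (p*p*c)

# Synthetic 28x28 image with 3 channels, channels last
img = torch.arange(28 * 28 * 3, dtype=torch.float32).reshape(28, 28, 3)
vec = to_vec_patches(img, 7)
print(vec.shape)

# Patch 0 is the top-left block, patch 1 is the block right below it
print(torch.equal(vec[0], img[:7, :7].flatten()))
print(torch.equal(vec[1], img[7:14, :7].flatten()))
```

With a patch size of 7 this yields 16 patches of length 7 * 7 * 3 = 147.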

## 1b. Input embedding

Now we need to take care of the input embedding:
* we have an input $I$ with shape $N \times P^2\cdot c$, where:
    * $N$ is the number of patches
    * $P$ is the patch size
    * $c$ is the channel size (1 in the example above)
* we need to linearly project $I$ into $z_0 \in \mathbb{R}^{N \times D}$, where $D$ is (hopefully) smaller than $P^2\cdot c$
* we also need to prepend a learnable `<class>` token to $z_0$
* and we need to add the **positional embedding/encoding** to it

Also, in a `Module`-like class, we need to take into account that the input will be 3-dimensional ($B \times N \times P^2\cdot c$, $B$ being the batch size)
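
For a single image, the steps listed above can be summarized in one formula, following the notation of the ViT paper. Here $I_i$ denotes the $i$-th row of $I$ (one vectorized patch); the symbols $x_\text{class}$ (the `<class>` token) and $E_\text{pos}$ (the positional embedding) are names assumed for this summary:

$$
z_0 = [\,x_\text{class};\ I_1 E;\ I_2 E;\ \dots;\ I_N E\,] + E_\text{pos},
\qquad E \in \mathbb{R}^{(P^2\cdot c)\times D},
\quad x_\text{class} \in \mathbb{R}^{D},
\quad E_\text{pos} \in \mathbb{R}^{(N+1)\times D}
$$

After prepending the token, $z_0$ has $N+1$ rows, which is why the positional embedding needs $N+1$ rows as well.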

In [None]:
class EmbedInput(nn.Module):
    def __init__(self, num_patches, patch_dim, latent_dim, bias=False, dropout_p=0.0):
        super().__init__()
        self.embed = nn.Linear(patch_dim, latent_dim, bias=bias) # this represents the matrix E
        self.dropout = nn.Dropout(dropout_p)
        # the next params are the same independent of the batch size
        self.class_token = nn.Parameter(torch.randn(1, 1, latent_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, latent_dim))
    
    def forward(self, X):
        z = self.embed(X)
        z = self.dropout(z)
        z = torch.cat((self.class_token.expand(z.shape[0], *self.class_token.shape[1:]), z), dim=1)
        z += self.pos_embedding
        return z

Let's try it on real data. We use MNISTM: it is an easy dataset but, unlike MNIST, its images have 3 channels.

Create the `Dataset` and pull one batch of data.

In [None]:
transforms = T.Compose([
    T.ToTensor(),
    T.Normalize([0.4639, 0.4676, 0.4199], [0.2534, 0.2380, 0.2618]),
    ToVecPatch(7, axis_channel=0), # image of size 28*28, patches of size 7
])

mnistm_train = mnistm.MNISTM(root="datasets/MNISTM", download=True, transform=transforms)
dataloader = torch.utils.data.DataLoader(mnistm_train, 128, shuffle=False, num_workers=8)
batch, _ = next(iter(dataloader))
batch.shape

Embed the batch:

In [None]:
embed = EmbedInput(num_patches=16, patch_dim=147, latent_dim=24)
embedded = embed(batch)
embedded.shape

## 2. Attention

* We have an embedded input $z_0$ of shape $B\times (N+1)\times D$

* We need to:
    * get $Q, K, V \in \mathbb{R}^{B\times (N+1)\times d}$ through linear projection from $z_0$
    * obtain $A = \text{softmax}(QK^\top/\sqrt{d})$
    * get $S = AV$

All of this is done for each head $h\in\{1,\dots,H\}$.
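
Before adding the head dimension, it may help to see the three steps for a single head on plain tensors. This is a minimal sketch under stated assumptions: the projection matrices `U_q`, `U_k`, `U_v` are random stand-ins (in a real model they are learned), and the sizes are arbitrary toy values.

```python
import torch

torch.manual_seed(0)
B, N1, D, d = 2, 5, 8, 4      # batch size, tokens (N+1), input dim, attention dim
z0 = torch.randn(B, N1, D)

# Random stand-ins for the learned projections of one head
U_q, U_k, U_v = (torch.randn(D, d) for _ in range(3))
Q, K, V = z0 @ U_q, z0 @ U_k, z0 @ U_v       # each B x (N+1) x d

# Scaled dot products between every query and every key, normalized per query
A = torch.softmax(Q @ K.transpose(-2, -1) / d ** 0.5, dim=-1)   # B x (N+1) x (N+1)
S = A @ V                                     # B x (N+1) x d

print(A.shape, S.shape)
print(A.sum(dim=-1))  # every row of A sums to 1 (up to float error)
```

Each row of `A` is a probability distribution over the tokens, so each row of `S` is a weighted average of the rows of `V`.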

In [None]:
class MultiheadedSelfAttention(nn.Module):
    def __init__(self, num_heads, input_dim, attention_dim, bias=False, dropout_p=0.0):
        '''
        input_dim -> D
        attention_dim -> d
        '''
        super().__init__()
        self.num_heads = num_heads
        self.attention_dim = attention_dim
        self.input_dim = input_dim
        self.u_qkv = nn.Linear(input_dim, attention_dim * num_heads * 3, bias=bias)
        self.u_msa = nn.Linear(attention_dim * num_heads, input_dim, bias=bias)
        self.dropout = nn.Dropout(dropout_p)
    
    def forward(self, z):
        QKV = self.u_qkv(z).chunk(3, dim=-1)
        separate_heads = lambda tensor: tensor.reshape(*tensor.shape[:2], self.num_heads, self.attention_dim).permute(0,2,1,3)
        Q, K, V = [separate_heads(t) for t in QKV]
        '''
        Why all that mess?
            Out of the linear projection we get a tensor of shape B x (N+1) x 3Hd
            We separate this tensor into three chunks of shape B x (N+1) x Hd
            We now need to split the heads out of the third dim (->reshape)
            Then, for simplicity, we move the head dim to second place (->permute)
            Shape: B x H x (N+1) x d
            Now, for each head, we need the dot products between the rows of Q and K
            This can be done elegantly with Einstein notation (einsum)
        '''
        A = torch.einsum("b h n d, b h m d -> b h n m", Q, K) / (self.attention_dim ** .5)
        '''
        Subscripts are single letters; we stick to lowercase ones
        b is batch size, h indexes the heads, d is attention_dim
        n and m both index the N+1 tokens, of Q and K respectively
        Although they have the same size, we must name them differently
        so torch knows which dims to keep and which one to sum over
        '''
        A = torch.nn.functional.softmax(A, dim=-1)
        S = torch.einsum("b h n m, b h m d -> b h n d", A, V)
        # undo separate_heads
        S = S.permute(0, 2, 1, 3)
        S = S.reshape(*S.shape[:2], S.shape[2]*S.shape[3])
        S = self.u_msa(S)
        return self.dropout(S)
        

In [None]:
msa = MultiheadedSelfAttention(num_heads=6, input_dim=24, attention_dim=7)
z_prime = msa(embedded)
z_prime.shape

## 3. MLP layer

Very easy, let's do it by ourselves...

$(B\times (N+1) \times D) \rightarrow (B\times (N+1)\times m) \rightarrow (B\times (N+1)\times D)$ 

**Add dropout wherever possible.**

**Use the `GELU` non-linearity (`nn.GELU`).**

In [None]:
class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, bias=True, dropout_p=0.0):
        super().__init__()
        pass

    def forward(self, X):
        pass


In [None]:
mlp = MLP(input_dim=24, hidden_dim=512)
z_1 = mlp(z_prime)
z_1.shape
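
If you get stuck, here is one possible way to fill in the exercise. It is a sketch, not the reference solution: the class name `MLPSketch` is hypothetical, and placing dropout after the activation and after the second linear layer is one common choice among several.

```python
import torch
from torch import nn

class MLPSketch(nn.Module):
    """One possible solution: D -> m -> D with GELU and dropout."""
    def __init__(self, input_dim, hidden_dim, bias=True, dropout_p=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=bias),   # D -> m
            nn.GELU(),
            nn.Dropout(dropout_p),
            nn.Linear(hidden_dim, input_dim, bias=bias),   # m -> D
            nn.Dropout(dropout_p),
        )

    def forward(self, X):
        # nn.Linear acts on the last dim, so B x (N+1) x D inputs work directly
        return self.net(X)

mlp_sketch = MLPSketch(input_dim=24, hidden_dim=512)
out = mlp_sketch(torch.randn(128, 17, 24))
print(out.shape)
```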

## 4. The MSA Layer

We need to put together the pieces from sections 2 and 3.

![](img/msa_layer.jpg)

In [None]:
class MSALayer(nn.Module):
    def __init__(self, embed_dim, num_heads, attention_dim, mlp_dim, bias_msa=False, bias_mlp=True):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.msa = MultiheadedSelfAttention(num_heads, embed_dim, attention_dim, bias=bias_msa)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_dim, bias=bias_mlp)
    
    def forward(self, X):
        # DIY
        pass

In [None]:
msal = MSALayer(embed_dim=24, num_heads=6, attention_dim=7, mlp_dim=512)
z = msal(embedded)
z.shape
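
A possible `forward` for this exercise, assuming the figure shows the pre-norm residual block of the ViT paper: $z' = \text{MSA}(\text{LN}(z)) + z$, followed by $z_\text{out} = \text{MLP}(\text{LN}(z')) + z'$. The sketch below is self-contained, so it uses a hypothetical `PreNormBlockSketch` class with plain `nn.Linear` stand-ins instead of the lab's attention and MLP modules.

```python
import torch
from torch import nn

class PreNormBlockSketch(nn.Module):
    """Hypothetical stand-alone version of the block: LN -> sublayer -> residual, twice."""
    def __init__(self, embed_dim, attention, mlp):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.attention = attention
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.mlp = mlp

    def forward(self, X):
        X = X + self.attention(self.layernorm1(X))   # z' = MSA(LN(z)) + z
        return X + self.mlp(self.layernorm2(X))      # z_out = MLP(LN(z')) + z'

# Stand-ins: any module mapping B x (N+1) x D to the same shape works here
block = PreNormBlockSketch(24, attention=nn.Linear(24, 24), mlp=nn.Linear(24, 24))
z_block = block(torch.randn(128, 17, 24))
print(z_block.shape)
```

In `MSALayer`, the same two lines would use `self.msa` in place of `self.attention`.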

## 5. The final MLP head

Easy...

$(B \times D) \rightarrow (B \times \kappa)$

Before the linear layer, we apply layer normalization.

In [None]:
class MLPHead(nn.Module):
    def __init__(self, input_dim, num_classes, bias=True):
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.fc = nn.Linear(input_dim, num_classes, bias=bias)
    
    def forward(self, X):
        return self.fc(self.layernorm(X))

## Let's put all of our pieces together

In [None]:
class ViT(nn.Module):
    def __init__(
        self,
        num_patches,
        patch_dim,
        embed_dim,
        num_msa_layers,
        num_heads,
        attention_dim,
        mlp_dim,
        num_classes,
        bias_embed=False,
        bias_msa=False,
        bias_mlp_att=True,
        bias_mlp_head=True
        # no dropout for simplicity
    ):
        super().__init__()
        self.input_embedder = EmbedInput(num_patches, patch_dim, embed_dim, bias=bias_embed)
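        # build each layer separately: multiplying a list would repeat the same module (shared weights)
        self.msa = nn.Sequential(
            *[MSALayer(embed_dim, num_heads, attention_dim, mlp_dim, bias_msa=bias_msa, bias_mlp=bias_mlp_att) for _ in range(num_msa_layers)]
        )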
        self.head = MLPHead(embed_dim, num_classes, bias=bias_mlp_head)
    
    def forward(self, X):
        '''
        X is a batch of B images, already decomposed into vectorized patches
        '''
        out = self.input_embedder(X)
        out = self.msa(out)
        out = out[:,0] # keep only the class token
        return self.head(out)


Let's first build a small ViT with 4 MSA layers and the same settings as in our experiments above, then check whether the output has the expected shape:

In [None]:
vit = ViT(num_patches=16, patch_dim=147, embed_dim=24, num_msa_layers=4, num_heads=6, attention_dim=7, mlp_dim=512, num_classes=10)
y = vit(batch)
y.shape
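
A side note on stacking layers inside `nn.Sequential`: multiplying a Python list repeats the *same* object, so `[layer] * k` yields one module applied k times with shared weights, while a list comprehension creates k independent modules. The sketch below illustrates the difference with small `nn.Linear` layers standing in for the real blocks; `count_params` is a hypothetical helper.

```python
from torch import nn

def count_params(module):
    # parameters() yields each distinct parameter tensor only once
    return sum(p.numel() for p in module.parameters())

shared = nn.Sequential(*([nn.Linear(4, 4)] * 3))                    # one module, used 3 times
independent = nn.Sequential(*[nn.Linear(4, 4) for _ in range(3)])   # three separate modules

print(shared[0] is shared[2])            # True
print(independent[0] is independent[2])  # False
print(count_params(shared), count_params(independent))  # 20 60
```

The shared version has a third of the parameters, and summary tools will typically report it as a single layer.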

Let's try a quick 10-epoch training run with default Adam. We shouldn't expect great performance, as our model is very small and transformers are designed to be trained on large-scale datasets.

In [None]:
optimizer = torch.optim.Adam(vit.parameters())
train.train_model(vit, dataloader, nn.CrossEntropyLoss(), optimizer, 10)
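
The `train_model` helper comes from the course's `scripts` package and is not shown here. As a rough idea of what such a helper might do, here is a minimal sketch of a supervised training loop; `train_sketch`, the toy linear model, and the random data are all assumptions made for illustration, not the course's implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def train_sketch(model, loader, loss_fn, optimizer, num_epochs):
    """Hypothetical minimal loop: returns the mean training loss of each epoch."""
    model.train()
    history = []
    for _ in range(num_epochs):
        running = 0.0
        for X, y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            optimizer.step()
            running += loss.item() * X.shape[0]
        history.append(running / len(loader.dataset))
    return history

# Random data shaped like a batch of vectorized patches (16 patches of length 147)
torch.manual_seed(0)
data = TensorDataset(torch.randn(64, 16, 147), torch.randint(0, 10, (64,)))
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(16 * 147, 10))
losses = train_sketch(
    toy_model, DataLoader(data, batch_size=32), nn.CrossEntropyLoss(),
    torch.optim.Adam(toy_model.parameters()), num_epochs=2,
)
print(losses)
```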

### Instantiate a ViT-Base model

![](img/vit_models.jpg)

Build it for images of size 224x224 and patches of size 16x16 (→196 patches).

We comply with the paper and set $d=D/H=768/12=64$

In [None]:
vit = ViT(num_patches=196, patch_dim=16*16*3, embed_dim=768, num_msa_layers=12, num_heads=12, attention_dim=64, mlp_dim=3072, num_classes=1000)
vit

In [None]:
_ = summary(vit)

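Let's cross-check the number of parameters by hand, counting each of the 12 `MSALayer`s separately. If the summary above reports fewer parameters, the `MSALayer`s inside the `nn.Sequential` are not being counted as independent layers.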

In [None]:
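# 12 MSA layers (2 LayerNorms + attention + MLP), MLP head (LayerNorm + linear),
# input embedding (projection E without bias + class token + positional embedding)
(1536 * 2 + 2359296 + 4722432) * 12 + 1536 + 769000 + (768 * 768 + 768 + 197 * 768)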
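
The per-layer terms in this sum can also be checked programmatically. The sketch below assumes the MLP consists of two linear layers with bias ($D \rightarrow m \rightarrow D$) and rebuilds only the parameter-carrying pieces of one layer; `count_params` is a hypothetical helper.

```python
from torch import nn

def count_params(module):
    return sum(p.numel() for p in module.parameters())

D, H, d, mlp_dim = 768, 12, 64, 3072

# Two LayerNorms, each with a weight and a bias of size D
layer_norms = 2 * count_params(nn.LayerNorm(D))
# QKV projection and output projection, both without bias
attention = (count_params(nn.Linear(D, 3 * H * d, bias=False))
             + count_params(nn.Linear(H * d, D, bias=False)))
# Two linear layers with bias (assumed MLP structure)
mlp = count_params(nn.Linear(D, mlp_dim)) + count_params(nn.Linear(mlp_dim, D))

print(layer_norms, attention, mlp)  # 3072 2359296 4722432
```

The same helper applied to the whole model, `count_params(vit)`, gives the total directly.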

This was just a demo showcasing one of the possible ways to build an architecture like the Vision Transformer.

If you need to use a ViT, I suggest relying on a pre-built implementation, like the one in `timm`.

You'll notice that existing implementations tend to make more use of the `einops` library, which provides operations that work the same way on PyTorch tensors and NumPy arrays, for transposing (permuting) a tensor, repeating given dims, and so on. Check out the [docs](https://einops.rocks/1-einops-basics/) if you're interested.