# 2150188401(2) Artificial Intelligence Assignment #2-2<br> Training Vision Transformers (PyTorch)

Copyright (C) Computer Science & Engineering, Soongsil University. This material is for educational uses only. Some contents are based on the material provided by other paper/book authors and may be copyrighted by them. Written by Haneul Pyeon, September 2024.

**To understand this assignment, please read the provided PDF file carefully.**

Now, you're going to leave behind your own implementations and instead migrate to one of the popular deep learning frameworks, **PyTorch**. <br>
In this notebook, you will learn to understand and build the basic components of the Vision Transformer (ViT). Then, you will classify images in the FashionMNIST dataset and explore the effects of different components of ViTs.
<br>
There are **2 sections**, and in each section, you need to follow the instructions to complete the skeleton code and explain it.

**Note**: certain details are missing or ambiguous on purpose, in order to test your knowledge of the related materials. However, if you really feel that something essential is missing and you cannot proceed to the next step, then contact the teaching staff with a clear description of your problem.

### Submitting your work:
<font color=red>**DO NOT clear the final outputs**</font> so that TAs can grade both your code and results.

### Some helpful tutorials and references for assignment #2-2:
- [1] PyTorch official documentation [[link]](https://pytorch.org/docs/stable/index.html).
- [2] Stanford CS231n lectures [[link]](http://cs231n.stanford.edu/).
- [3] Alexey Dosovitskiy et al., "An Image is Worth 16 x 16 Words: Transformers for Image Recognition at Scale", ICLR 2021 [[pdf]](https://arxiv.org/pdf/2010.11929.pdf).

## 1. Building Vision Transformer
Here, you will build the basic components of the Vision Transformer (ViT). <br>

![Vision Transformer](imgs/ViT.png)

Using the explanation and code provided as guidance, define each component of ViT. <br>


#### ViT architecture:
* A ViT model consists of an input patch embedding, positional embeddings, a transformer encoder, etc.
* Patch embedding
* Positional embeddings
* Transformer encoder with
    * Attention module
    * MLP module

In [6]:
import torch
import torch.nn as nn

##### Patch Embed

**Initialization**: When you create an instance of the PatchEmbed class, you specify img_size, patch_size, in_chans, and embed_dim. img_size is the height and width of the (square) input image, patch_size is the side length of each patch, in_chans is the number of input image channels (e.g., 3 for RGB images), and embed_dim is the dimension of each patch embedding.

**Convolutional Projection**: Inside the PatchEmbed class, a 2D convolutional layer (nn.Conv2d) is used to perform a patch-based projection. This convolutional layer has a kernel size of patch_size, which defines the size of each patch, and a stride of patch_size, which ensures that patches do not overlap. The convolutional layer therefore extracts each image patch and projects it to embed_dim output channels.

**Reshaping**: After the convolutional projection, the output tensor is flattened over its two spatial dimensions and transposed. It is transformed from a 4D tensor with dimensions (batch_size, embed_dim, H / patch_size, W / patch_size) to a 3D tensor with dimensions (batch_size, num_patches, embed_dim). num_patches is the total number of non-overlapping patches in the image, and embed_dim is the number of output channels of the convolutional layer.
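
To make the projection and reshaping concrete, the following minimal sketch reproduces them in plain Python on a toy single-channel image. It relies on the fact that a convolution whose kernel size and stride both equal patch_size applies one shared linear map to each flattened patch. The helper names (`extract_patches`, `project`) and the toy weights are illustrative assumptions and are not part of the assignment code.

```python
def extract_patches(image, patch_size):
    """Split a square 2D image (list of rows) into flattened, non-overlapping patches.

    Patches come out in row-major order, the same order produced by a
    stride-patch_size convolution followed by flatten(2).
    """
    size = len(image)
    assert size % patch_size == 0, "image size must be divisible by patch size"
    patches = []
    for top in range(0, size, patch_size):
        for left in range(0, size, patch_size):
            patch = [image[top + i][left + j]
                     for i in range(patch_size)
                     for j in range(patch_size)]
            patches.append(patch)
    return patches


def project(patches, weight, bias):
    """Apply the same linear map to every flattened patch (one output row per patch)."""
    return [[sum(w * p for w, p in zip(row, patch)) + b
             for row, b in zip(weight, bias)]
            for patch in patches]


# Toy 4x4 image with patch_size=2 gives (4 // 2) ** 2 = 4 patches.
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = extract_patches(image, patch_size=2)

# Hypothetical embed_dim=3 projection: each weight row plays the role of one conv kernel.
weight = [[1, 0, 0, 0], [0, 0, 0, 1], [1, 1, 1, 1]]
bias = [0, 0, 0]
tokens = project(patches, weight, bias)
print(len(tokens), len(tokens[0]))  # 4 3, i.e. (num_patches, embed_dim)
```

In the PyTorch version, nn.Conv2d performs both steps at once for a whole batch, and its learned kernels take the place of the toy weight rows.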

In [7]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        # 2D convolution layer
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2)

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x # output dimension must be: (batch size, number of patches, embed_dim)

##### Attention

**Initialization**
* dim: The input dimension of the sequence. This is the dimensionality of the queries, keys, and values.
* num_heads: The number of attention heads to use. Multi-head attention allows the model to focus on different parts of the input simultaneously.

**Linear Projections (qkv and proj)**: The qkv linear layer takes the input sequence and projects it into three parts: queries (q), keys (k), and values (v). The output of this layer has a shape of (batch_size, sequence_length, 3 * dim). The proj linear layer maps the concatenated outputs of all heads back to dimension dim.

**Forward Pass (forward method)**: In the forward pass, the input tensor x is processed through the attention mechanism. Here's what happens:<br>
* The linear projection qkv is applied to x, producing a tensor of shape (batch_size, sequence_length, 3 * dim).
* This tensor is reshaped to have dimensions (batch_size, sequence_length, 3, num_heads, head_dim), where head_dim = dim / num_heads. The permute operation rearranges the dimensions to (3, batch_size, num_heads, sequence_length, head_dim), making it suitable for multi-head attention.
* The three parts, q, k, and v, are extracted from the reshaped tensor.
* The attention scores are computed by taking the dot product of queries q and keys k. The result is scaled by self.scale.
* The attention scores are passed through a softmax activation along the last dimension (sequence_length), producing attention weights.
* The weighted sum of values v is computed using the attention weights.
* The result is transposed and reshaped to its original shape, and then passed through the proj linear layer.
* The final output is returned.
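
The score, softmax, and weighted-sum steps above can be sketched for a single head in plain Python. This is a minimal illustration on nested lists with arbitrarily chosen toy values. It omits the qkv and proj layers and the batch and head dimensions, and the function names are assumptions made for this example.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(s - peak) for s in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v):
    """Single-head scaled dot-product attention on lists of token vectors."""
    head_dim = len(q[0])
    scale = head_dim ** -0.5  # same role as self.scale in the Attention class
    # scores[i][j] = scaled dot product of query i with key j
    scores = [[scale * sum(a * b for a, b in zip(qi, kj)) for kj in k] for qi in q]
    weights = [softmax(row) for row in scores]  # each row sums to 1
    # output i = attention-weighted sum of the value vectors
    out = [[sum(w * vj[d] for w, vj in zip(w_row, v)) for d in range(len(v[0]))]
           for w_row in weights]
    return out, weights


# Toy example: 3 tokens, head_dim = 2.
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = attention(q, k, v)
print(len(out), len(out[0]))  # one output vector of size head_dim per token
```

Multi-head attention runs this computation independently for each of the num_heads slices of size head_dim and concatenates the results before the proj layer.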

In [8]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x # output dimension must be: (batch size, number of patches, embed_dim)

##### MLP

The MLP module must consist of three layers:
* fully connected layer 1
* activation layer
* fully connected layer 2

In [9]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        # Mlp layer
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)


        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################

    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x # output dimension must be: (batch size, number of patches, out_features)

##### Transformer Block
The transformer block contains an attention module and an MLP module, each wrapped in a residual connection.
Refer to the following image and build the forward pass.

![Transformer Block](imgs/TransformerBlock.png)
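
As a rough illustration of the residual pattern, the sketch below applies a pre-norm block to a single token vector in plain Python. It assumes the normalize-then-sublayer-then-add ordering used in the ViT paper [3]. The `attn` and `mlp` arguments are stand-in callables, and `layer_norm` has no learned scale or shift, so this is a simplification rather than the assignment's Block class.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize one token vector to zero mean and unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_norm_block(x, attn, mlp):
    """Compute x + attn(norm(x)), then x + mlp(norm(x))."""
    x = [a + b for a, b in zip(x, attn(layer_norm(x)))]
    x = [a + b for a, b in zip(x, mlp(layer_norm(x)))]
    return x


def zero_sublayer(t):
    """Stand-in sublayer that contributes nothing."""
    return [0.0] * len(t)


x = [1.0, 2.0, 3.0, 4.0]
# With sublayers that output zeros, the skip path returns the input unchanged.
print(pre_norm_block(x, zero_sublayer, zero_sublayer))  # [1.0, 2.0, 3.0, 4.0]
```

The skip path is why each sublayer only needs to learn a correction to its input rather than a full replacement for it.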

In [10]:
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer)

    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################

        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x


##### Vision Transformer

Using all the components that you built above, **complete** the vision transformer class.

### torch.cat

Concatenates the given sequence of tensors along the specified dimension. All tensors must either have the same shape (except in the concatenating dimension) or be a 1-D empty tensor with size (0,).

`torch.cat()` can be seen as an inverse operation for `torch.split()` and `torch.chunk()`.

`torch.cat()` can be best understood via examples.

#### Parameters

- **tensors** (sequence of Tensors): any Python sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the concatenating dimension.

- **dim** (int, optional): the dimension over which the tensors are concatenated.

#### Keyword Arguments

- **out** (Tensor, optional): the output tensor.

In [11]:
# example
x = torch.randn(2, 3)
x

tensor([[-0.0181, -0.8599, -0.9947],
        [-1.4668,  0.4443,  0.0269]])

In [12]:
exam1 = torch.cat((x, x, x), 0)
exam1

tensor([[-0.0181, -0.8599, -0.9947],
        [-1.4668,  0.4443,  0.0269],
        [-0.0181, -0.8599, -0.9947],
        [-1.4668,  0.4443,  0.0269],
        [-0.0181, -0.8599, -0.9947],
        [-1.4668,  0.4443,  0.0269]])

In [13]:
exam2 = torch.cat((x, x, x), 1)
exam2

tensor([[-0.0181, -0.8599, -0.9947, -0.0181, -0.8599, -0.9947, -0.0181, -0.8599,
         -0.9947],
        [-1.4668,  0.4443,  0.0269, -1.4668,  0.4443,  0.0269, -1.4668,  0.4443,
          0.0269]])
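
The shape rule described above can also be written as a few lines of plain Python. `cat_shape` is a hypothetical helper for illustration only: it is not a PyTorch function, and it ignores the empty-tensor special case. The last call uses example sizes (batch 128, 49 patches, embedding size 32) to show what happens when a class token is prepended to the patch tokens.

```python
def cat_shape(shapes, dim):
    """Return the shape that concatenating tensors of the given shapes would have.

    Mirrors the documented rule: all shapes must agree except along `dim`.
    """
    first = shapes[0]
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("all tensors must have the same number of dimensions")
        for axis, (a, b) in enumerate(zip(first, shape)):
            if axis != dim and a != b:
                raise ValueError(
                    f"sizes differ on non-concatenated axis {axis}: {a} vs {b}")
    result = list(first)
    result[dim] = sum(shape[dim] for shape in shapes)
    return tuple(result)


# The examples above: three (2, 3) tensors.
print(cat_shape([(2, 3)] * 3, dim=0))  # (6, 3)
print(cat_shape([(2, 3)] * 3, dim=1))  # (2, 9)

# Class token (B, 1, D) prepended to patch tokens (B, N, D) along dim=1.
print(cat_shape([(128, 1, 32), (128, 49, 32)], dim=1))  # (128, 50, 32)
```

The extra position along dim=1 is why the positional embedding needs num_patches + 1 entries.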

In [14]:
class VisionTransformer(nn.Module):
    """ Vision Transformer """

    def __init__(self, img_size=28, patch_size=4, in_chans=1, num_classes=10, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., norm_layer=nn.LayerNorm, ):
        super().__init__()
        self.num_features = self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################
        # similarly to cls_token, define a learnable positional embedding that matches the patchified input token size.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################

        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,  norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.head = nn.Linear(
            embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        ##############################################################################
        #                           IMPLEMENT YOUR CODE                              #
        ##############################################################################
        B = x.shape[0]

        # Patch Embedding
        x = self.patch_embed(x)

        # Concatenate class tokens to patch embedding
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embedding to patches
        x = x + self.pos_embed

        # Forward through encoder blocks
        for blk in self.blocks:
            x = blk(x)

        # Normalize the output
        x = self.norm(x)

        # Use class token for classification
        cls_token_output = x[:, 0]

        # Classifier head
        x = self.head(cls_token_output)

        ##############################################################################
        #                              END YOUR CODE                                 #
        ##############################################################################
        return x

## 2. Training a small ViT model on FashionMNIST dataset.

Define and train a Vision Transformer on the FashionMNIST dataset. **(You must reach a test accuracy above 85% for full points.)** <br>
Train with at least 5 different hyperparameter settings, varying the following ViT hyperparameters.
Report the setting that gives the best performance.

#### ViT hyperparameters:
* patch_size
* embed_dim
* depth
* num_heads
* mlp_ratio
* etc.
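
Before launching several training runs, it can help to sanity-check each setting. The sketch below is a hypothetical helper (`describe_vit` is not part of the assignment code) that checks the two divisibility constraints and counts learnable parameters by hand. It assumes the layer definitions of the classes in Section 1, with bias terms in every nn.Linear and nn.Conv2d and a weight and bias in every LayerNorm. The default arguments mirror one setting tried in this notebook.

```python
def describe_vit(img_size=28, patch_size=4, in_chans=1, num_classes=10,
                 embed_dim=32, depth=6, num_heads=16, mlp_ratio=4):
    """Validate a ViT setting and count its learnable parameters."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")

    d = embed_dim
    hidden = int(d * mlp_ratio)
    num_patches = (img_size // patch_size) ** 2

    patch_embed = in_chans * d * patch_size * patch_size + d  # conv weight + bias
    tokens = d + (num_patches + 1) * d                        # cls_token + pos_embed
    layer_norm = 2 * d                                        # weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)                  # qkv + proj
    mlp = (d * hidden + hidden) + (hidden * d + d)            # fc1 + fc2
    block = 2 * layer_norm + attn + mlp
    head = d * num_classes + num_classes

    return {
        "num_patches": num_patches,
        "seq_len": num_patches + 1,  # patches plus the class token
        "head_dim": d // num_heads,
        "params": patch_embed + tokens + depth * block + layer_norm + head,
    }


print(describe_vit())
print(describe_vit(patch_size=7, num_heads=4))
```

Note that num_heads does not enter the parameter count: it only changes how the embed_dim channels are split across heads (head_dim = embed_dim // num_heads).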


In [15]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import FashionMNIST

In [23]:
def Train():
    ##############################################################################
    #                           IMPLEMENT YOUR CODE                              #
    ##############################################################################

    patch_size = 4
    embed_dim = 32
    depth = 6
    num_heads = 16  # make sure embed_dim is divisible by num_heads!
    mlp_ratio = 4

    # 85.3
    # patch_size=7
    # embed_dim=32
    # depth=6
    # num_heads=4 # make sure embed_dim is divisible by num_heads!
    # mlp_ratio=4


    # patch_size=7
    # embed_dim=32
    # depth=6
    # num_heads=16 # make sure embed_dim is divisible by num_heads!
    # mlp_ratio=4


    # patch_size=7
    # embed_dim=16
    # depth=6
    # num_heads=4 # make sure embed_dim is divisible by num_heads!
    # mlp_ratio=4

    ##############################################################################
    #                              END YOUR CODE                                 #
    ##############################################################################

    # Loading data
    transform = ToTensor()

    train_set = FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_set = FashionMNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device: ", device, f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "")

    model = VisionTransformer(patch_size=patch_size, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio).to(device)
    model_path = './vit.pth'
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            train_loss += loss.detach().cpu().item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop (switching to eval mode is good practice before evaluation)
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.detach().cpu().item() / len(test_loader)

            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).detach().cpu().item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

    torch.save(model.state_dict(), model_path)
    print('Saved Trained Model.')

Train()

Using device:  cuda (Tesla T4)


Epoch 1/5 loss: 0.81



Epoch 2/5 loss: 0.47



Epoch 3/5 loss: 0.42



Epoch 4/5 loss: 0.39



Epoch 5/5 loss: 0.37


Testing: 100%|██████████| 79/79 [00:01<00:00, 48.55it/s]

Test loss: 0.40
Test accuracy: 85.18%
Saved Trained Model.





### Describe what you did and discovered here
In this cell, write all the settings you tried and the performance you obtained for each. Report what you did and what you discovered from the trials.
You may write in Korean or English.

    patch_size=4
    embed_dim=32
    depth=6
    num_heads=8
    mlp_ratio=4
Test accuracy: 84.32%

    patch_size=4
    embed_dim=16
    depth=6
    num_heads=8
    mlp_ratio=4
Test accuracy: 83.90% -> Accuracy decreased when embed_dim was reduced.

    patch_size=4
    embed_dim=16
    depth=6
    num_heads=4
    mlp_ratio=4
Test accuracy: 83.41% -> Accuracy decreased when num_heads was reduced. The drop appears to come from the model's reduced expressive power.

    patch_size=4
    embed_dim=32
    depth=6
    num_heads=16
    mlp_ratio=4
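Test accuracy: 85.18% -> Accuracy increased when num_heads was increased. Because the model can capture more information about the patches, its expressive power appears to increase, which improves performance.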