# M2177.004300 002 Deep Learning Assignment #2<br> Part 2: Vision Transformer


Copyright (C) Data Science & AI Laboratory, Seoul National University. This material is for educational uses only. Some contents are based on the material provided by other paper/book authors and may be copyrighted by them. Written by DongHyeok Lee, October 2024


**To understand this assignment, please read the accompanying PDF file carefully.**

In this notebook, you will work on the following:

1. Implementing SwinTransformer:

- SwinTransformer is a variant of the Vision Transformer that uses hierarchical feature representation and shifted window-based self-attention mechanisms.
- You will implement key components of the SwinTransformer, including window partitioning, shifted window attention mechanisms, and hierarchical feature extraction processes.
- This exercise will provide you with a deep understanding of the inner workings of state-of-the-art vision models.

2. Comparing ViT and Swin-ViT with CIFAR-10:

- You will compare the performance of ViT and Swin-ViT models using the CIFAR-10 dataset.
- You'll utilize the Hugging Face library to easily implement and train both models.
- You'll compare the models in various aspects such as model architecture, training speed, and accuracy, to understand the strengths and weaknesses of each model.
- This comparative analysis will help you develop the ability to choose appropriate models for various vision tasks.

**Note**: certain details are missing or ambiguous on purpose, in order to test your knowledge of the related materials. However, if you feel that something essential is missing and you cannot proceed to the next step, contact the teaching staff with a clear description of your problem.

### Submitting your work:

<font color=red>**DO NOT clear the final outputs**</font> so that TAs can grade both your code and results.


This section of the assignment focuses on utilizing the Hugging Face library to train and compare Vision Transformer (ViT) and Swin Transformer (Swin-ViT) models. The primary objectives are:

1. Performance and Efficiency Comparison:

   - Compare the performance of ViT and Swin-ViT on the same task (e.g., image classification on CIFAR-10).
   - Analyze the differences in terms of:
     a) Accuracy: Evaluate which model achieves higher classification accuracy.
     b) Parameters: Compare the number of trainable parameters in each model.
     c) FLOPs (Floating Point Operations): Assess the computational efficiency of each model.
   - Explain the reasons behind the observed differences, considering the architectural distinctions between ViT and Swin-ViT.

2. Hyperparameter Tuning for Optimal Performance:
   - Experiment with various hyperparameters for both ViT and Swin-ViT models to achieve the best possible performance.
   - Parameters to tune may include:
     a) Learning rate and learning rate schedule
     b) Batch size
     c) Number of attention heads
     d) Embedding dimensions
     e) Number of layers
     f) Dropout rates
   - Document the impact of different hyperparameter configurations on model performance.
   - Identify the optimal set of hyperparameters for each model that yields the highest accuracy on the given task.
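
For the parameter comparison in objective 1, one common way to count trainable parameters in PyTorch is to sum `numel()` over the parameters that require gradients. The sketch below is illustrative only: `count_trainable_parameters` and `toy_model` are hypothetical names, and the small linear layer merely stands in for a ViT or Swin-ViT model.

```python
import torch.nn as nn


def count_trainable_parameters(model: nn.Module) -> int:
    """Sum the element counts of all parameters that receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Toy stand-in for a real model: 10 * 5 weights + 5 biases = 55 parameters.
toy_model = nn.Linear(10, 5)
print(count_trainable_parameters(toy_model))  # 55
```

The same helper should accept any `nn.Module`, so it can be applied to both models before comparing them.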

By completing this section, you will gain hands-on experience in:

- Implementing and training state-of-the-art vision models using the Hugging Face library.
- Conducting comparative analysis of different transformer architectures for computer vision tasks.
- Practicing hyperparameter tuning to optimize model performance.
- Critically analyzing the trade-offs between model complexity, computational efficiency, and accuracy.

This exercise will deepen your understanding of modern vision transformers and equip you with practical skills in model selection and optimization for real-world computer vision applications.


### 2-2-1 Implementing SwinTransformer


<img src="./images/swin_transformer_architecture.png" alt="Image description" width="500"/>


Completing a generally functional Swin Transformer module can be challenging.  
<font color=red>**Therefore, in this problem, we aim to simplify the conditions. The conditions are as follows:**</font>


<img src="./images/window-attention.png" alt="Image description" width="500"/>


#### Procedure:

1. Apply Window Self-Attention:

- Pass the input tensor $X$ of shape $(B \times H \times W \times C)$ through a window-based self-attention mechanism with $n$ heads and a window size of $w$.

2. Concatenate Attention Outputs and Apply MLP:

- Concatenate the outputs from the $n$ attention heads along the channel dimension.
- Pass the concatenated result through an MLP to adjust the channel size back to the original dimension $C$

3. Residual Connection:

- Add the original input $X$ to the output from the previous step to form a residual connection.

#### Details:

a. Self-Attention within Square Windows:

- The self-attention mechanism operates within square windows of size $w \times w$.

b. Window Shifting:

- Apply window shifts that are equal in size to the window dimensions, effectively sliding the windows across the input tensor.

c. Zero Padding:

- If necessary, add zero-padding to the **right** and **bottom** edges of the input tensor to ensure it can be evenly divided into windows.

d. Calculation of $q$, $k$, and $v$:

- Compute the query ($q$), key ($k$), and value ($v$) tensors using the provided **IdentityMLP**.
- Do not include additional biases, normalization layers, or dropout during this computation.

e. Channel-Wise Concatenation and Channel Size Adjustment:

- Concatenate the outputs from the $n$ self-attention operations along the channel axis.
- Pass the concatenated tensor through the **HeaderConcatMLP** to adjust the channel size to match that of the input $X$.
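
Details (a) to (c) can be illustrated with a small padding-and-partition helper. This is only a sketch under stated assumptions, not the reference solution: `pad_and_partition` is a hypothetical name, the window ordering (row-major over the window grid) is one possible choice, and the PDF remains the authoritative specification. It zero-pads a $(B \times H \times W \times C)$ tensor on the right and bottom, then splits it into non-overlapping windows.

```python
import torch
import torch.nn.functional as F


def pad_and_partition(x: torch.Tensor, window_size: tuple[int, int]) -> torch.Tensor:
    """Zero-pad (B, H, W, C) on the right/bottom, then split into windows.

    Returns a tensor of shape (B * num_windows, wh * ww, C).
    """
    B, H, W, C = x.shape
    wh, ww = window_size
    pad_h = (wh - H % wh) % wh  # rows to add at the bottom
    pad_w = (ww - W % ww) % ww  # columns to add on the right
    # F.pad lists pads from the last dim backwards:
    # (C_front, C_back, W_left, W_right, H_top, H_bottom)
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    # Split H into (rows of windows, wh) and W into (cols of windows, ww).
    x = x.reshape(B, Hp // wh, wh, Wp // ww, ww, C)
    # Bring the two window-grid axes together, then flatten each window.
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, wh * ww, C)


demo = torch.arange(2 * 5 * 7 * 3, dtype=torch.float32).reshape(2, 5, 7, 3)
demo_windows = pad_and_partition(demo, (4, 4))
print(demo_windows.shape)  # torch.Size([8, 16, 3])
```

Here a 5×7 input is padded to 8×8, giving four 4×4 windows per batch element. After attention, the inverse reshape and a crop back to $(H, W)$ would be needed to satisfy the output-shape assertion.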


In [None]:
import torch
import torch.nn as nn

from typing import Tuple


# use identity mlp when calculating q, k, v
class IdentityMLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Linear(dim, dim)
        self.initialize_weights_ones()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

    def initialize_weights_ones(self):
        nn.init.ones_(self.mlp.weight)
        nn.init.ones_(self.mlp.bias)


# use HeaderConcatMLP when merging heads
class HeaderConcatMLP(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.mlp = nn.Linear(in_dim, out_dim)
        self.initialize_weights_ones()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)

    def initialize_weights_ones(self):
        nn.init.ones_(self.mlp.weight)
        nn.init.ones_(self.mlp.bias)


def cal_window_transformer_block(
    x,
    #
    window_size: Tuple[int, int],
    num_heads: int,
):
    B, H, W, C = x.shape
    assert num_heads > 0, "num_heads must be greater than 0"
    assert (
        0 < window_size[0] < H and 0 < window_size[1] < W
    ), "window_size must be less than image size"

    ##############################################################################
    #                          IMPLEMENT YOUR CODE                               #
    ##############################################################################

    ##############################################################################
    #                          END OF YOUR CODE                                  #
    ##############################################################################

    output: torch.Tensor = x
    B_, H_, W_, C_ = output.shape
    assert (B_, H_, W_, C_) == (
        B,
        H,
        W,
        C,
    ), "output shape should be same as input shape"
    return output

<font color='red'>Important!</font>  
Write the final result to `model_checkpoints/cal_window_transformer_block.py` and submit it.  
Errors resulting from modifications to any part of the script other than the function implementation sections will be considered as a failure to submit.


### 2-2-2 Vision Transformer for image classification


This part of the assignment trains ViT-family models for image classification on the CIFAR-10 dataset.


#### Load the CIFAR-10 dataset


In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader
import torchvision
import torch


def load_cifar10(data_dir, image_size: tuple[int, int] = (224, 224)):
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
            ),
        ]
    )

    train_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )
    return train_dataset, test_dataset


def get_dataloader(train_dataset, test_dataset, batch_size, num_workers):
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, test_loader

In [None]:
from PIL import Image
import numpy as np
import cv2
import torch


def visualize_tensor_image(
    tensor_image: torch.Tensor,
    mean: list[float] = [0.5, 0.5, 0.5],
    std: list[float] = [0.5, 0.5, 0.5],
    size: tuple[int, int] = (32, 32),
) -> Image.Image:
    assert len(tensor_image.shape) == 3, "Tensor must have 3 dimensions"
    tensor_image = tensor_image.permute(1, 2, 0)
    tensor_image = tensor_image.cpu().numpy()
    tensor_image = (tensor_image * std) + mean
    tensor_image = (tensor_image * 255).clip(0, 255).astype(np.uint8)
    tensor_image = cv2.resize(tensor_image, size)
    return Image.fromarray(tensor_image)

In [None]:
image_size = (32, 32)
train_dataset, test_dataset = load_cifar10("./data", image_size)
class_names = train_dataset.classes

In [None]:
from matplotlib import pyplot as plt
from random import randint

plt.figure(figsize=(10, 10))
for i in range(16):
    # randint includes both endpoints, so the upper bound must be len - 1
    sample_index = randint(0, len(train_dataset) - 1)
    rv_image, rv_class = train_dataset[sample_index]
    rv_label = class_names[rv_class]
    plt.subplot(4, 4, i + 1)
    plt.imshow(
        visualize_tensor_image(
            rv_image, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], size=image_size
        )
    )
    plt.title(f"{rv_label} ({rv_class})")
    plt.axis("off")
plt.show()

#### Create the data loaders and validate the batches


In [6]:
batch_size = 32
num_classes = 10

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

# check data validation
for images, labels in train_loader:
    assert images.shape == (
        batch_size,
        3,
        image_size[0],
        image_size[1],
    ), "images shape is not correct"
    assert labels.shape == (batch_size,), "labels shape is not correct"
    assert labels.max() < num_classes, "labels are out of range"

for images, labels in test_loader:
    assert images.shape == (
        batch_size,
        3,
        image_size[0],
        image_size[1],
    ), "images shape is not correct"
    assert labels.shape == (batch_size,), "labels shape is not correct"
    assert labels.max() < num_classes, "labels are out of range"

### Build and train ViT and Swin ViT with HF


Hugging Face provides ready-made implementations of ViT and Swin-ViT, making it convenient to create and use these models. Below are some simple examples; more detailed information is available at the following links:

1. Vision Transformer (ViT):

   - Hugging Face documentation: [ViT Model](https://huggingface.co/docs/transformers/model_doc/vit)
   - Example usage:

     ```python
     from transformers import ViTModel, ViTConfig

     # Creating a ViT model
     configuration = ViTConfig()
     model = ViTModel(configuration)

     # For a pre-trained model
     model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
     ```

2. Swin Transformer:

   - Hugging Face documentation: [Swin Transformer Model](https://huggingface.co/docs/transformers/model_doc/swin)
   - Example usage:

     ```python
     from transformers import SwinModel, SwinConfig

     # Creating a Swin Transformer model
     configuration = SwinConfig()
     model = SwinModel(configuration)

     # For a pre-trained model
     model = SwinModel.from_pretrained("microsoft/swin-tiny-patch4-window7-224")
     ```
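
When adapting these configurations to 32×32 CIFAR-10 images, it may help to check how many tokens a given patch size produces, since this drives the cost of self-attention. The sketch below uses plain arithmetic and assumes a square image, a patch size that divides the image size, and one extra `[CLS]` token as in the Hugging Face ViT implementation; `vit_sequence_length` is a hypothetical helper name.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Number of tokens a ViT encoder sees: one per patch plus the [CLS] token."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


for patch_size in (16, 8, 4):
    print(patch_size, vit_sequence_length(32, patch_size))
# 16 5
# 8 17
# 4 65
```

With `patch_size=16`, a 32×32 image yields only four patches, which is worth keeping in mind when choosing the patch size for tuning.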


In [7]:
# if torchinfo is not installed, install it
!pip install torchinfo
from torchinfo import summary

In [None]:
# https://huggingface.co/docs/transformers/model_doc/swin#transformers.SwinConfig
from transformers import SwinConfig
from transformers import SwinForImageClassification

swin_config = SwinConfig(
    image_size=image_size[0],
    num_labels=num_classes,
    embed_dim=32,
    depths=[2],
    num_heads=[4],
    window_size=5,
    drop_path_rate=0.1,
)

swin_model = SwinForImageClassification(swin_config)
summary(swin_model, input_size=(1, 3, image_size[0], image_size[1]))

In [None]:
from transformers import ViTConfig
from transformers import ViTForImageClassification

vit_config = ViTConfig(
    image_size=image_size[0],
    num_labels=num_classes,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    patch_size=16,
    intermediate_size=3072,
    hidden_act="gelu",
    classifier_dropout=0.1,
)

vit_model = ViTForImageClassification(vit_config)


summary(vit_model, input_size=(1, 3, image_size[0], image_size[1]))

#### For each model (ViT and Swin-ViT), create and submit the best version in terms of performance.

Specifically:

1. Performance Optimization:

   - Experiment with various hyperparameters and configurations for both ViT and Swin-ViT.
   - Aim to achieve the highest possible accuracy on the given task (e.g., CIFAR-10 classification).

2. Comparative Analysis:

   - Provide a short comparison between your ViT and Swin-ViT models.
   - Discuss the strengths and weaknesses of each in terms of performance and computational efficiency at **model_checkpoints/assignment2-2-2/report.md**  
     **Tip:** Keep it brief. Length and content are not part of the grading score.

3. Best Model Selection:
   - **Submit the best model from either ViT or Swin-ViT.**  
     <font color='red'>Scores will be assigned in rank order, starting from the highest accuracy, and a perfect score will be given for an accuracy of 73% or above.</font>
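
One of the hyperparameters listed earlier is the learning-rate schedule. As a minimal sketch, assuming cosine annealing is among the schedules you want to try, the snippet below attaches `torch.optim.lr_scheduler.CosineAnnealingLR` to an Adam optimizer and steps it once per epoch. The names `demo_model`, `demo_optimizer`, `demo_scheduler`, and `demo_epochs` are hypothetical, and the tiny linear layer only stands in for a real model.

```python
import torch.nn as nn
import torch.optim as optim

demo_model = nn.Linear(4, 2)  # hypothetical stand-in for ViT or Swin-ViT
demo_epochs = 5
demo_optimizer = optim.Adam(demo_model.parameters(), lr=1e-3)
demo_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    demo_optimizer, T_max=demo_epochs
)

lr_history = []
for epoch in range(demo_epochs):
    # Record the learning rate used during this epoch.
    lr_history.append(demo_scheduler.get_last_lr()[0])
    # ... one epoch of training would run here ...
    demo_optimizer.step()  # placeholder for the real update steps
    demo_scheduler.step()  # advance the schedule once per epoch

print(lr_history)
```

The recorded learning rate starts at the configured value and decreases each epoch. In the training cell, the scheduler step would go right after each call to `train`.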


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import Optimizer

from tqdm import tqdm


def train(
    model: nn.Module,
    loader: DataLoader,
    optimizer: Optimizer,
    criterion: nn.Module,
    device: torch.device,
):
    model.train()
    model.to(device)
    running_loss = 0.0
    total = 0
    correct = 0
    pbar = tqdm(loader, desc="Training")

    for images, labels in pbar:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(logits, 1)
        total += labels.size(0)
        correct += (preds == labels).sum().item()

        # Iterating over the tqdm object already advances the bar
        pbar.set_postfix(loss=loss.item())

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc


def evaluate(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
):
    model.eval()
    model.to(device)
    running_loss = 0.0
    total = 0
    correct = 0
    pbar = tqdm(loader, desc="Evaluating")

    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            logits = outputs.logits
            loss = criterion(logits, labels)

            running_loss += loss.item() * images.size(0)
            _, preds = torch.max(logits, 1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

            # Iterating over the tqdm object already advances the bar
            pbar.set_postfix(loss=loss.item())

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [10]:
import torch
import torch.optim as optim
import torch.nn as nn

from transformers import ViTConfig
from transformers import ViTForImageClassification

# cpu or gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# or if you have MPS
if torch.backends.mps.is_available():
    device = torch.device("mps")

In [26]:
NUM_CLASSES = 10  # DO NOT CHANGE

model_name = "vit"  # <---- will be used for saving model
batch_size = 32  # <---- feel free to change
optimizer = optim.Adam  # <---- feel free to change
optimizer_kwargs = {  # <---- feel free to change
    "lr": 0.001,
}
epochs = 10  # <---- feel free to change
seed = 42  # <---- feel free to change

# <---- feel free to change
model_config = ViTConfig(
    image_size=image_size[0],
    num_labels=NUM_CLASSES,
    hidden_size=768,
    num_hidden_layers=1,
    num_attention_heads=1,
    patch_size=16,
    intermediate_size=3072,
    hidden_act="gelu",
    classifier_dropout=0.1,
)
model = ViTForImageClassification(model_config)

# or use Swin-ViT
#
# model_config = SwinConfig(
#     patch_size=4,
#     image_size=image_size[0],
#     num_labels=num_classes,
#     embed_dim=32,
#     depths=[2],
#     num_heads=[4],
#     window_size=5,
#     drop_path_rate=0.1,
# )
# model = SwinForImageClassification(model_config)

In [None]:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

optimizer = optimizer(model.parameters(), **optimizer_kwargs)
criterion = nn.CrossEntropyLoss()

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, drop_last=True
)

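for epoch in range(epochs):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    print(
        f"Epoch {epoch + 1}/{epochs} - "
        f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}"
    )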
eval_loss, eval_acc = evaluate(model, test_loader, criterion, device)
print(f"Evaluation Loss: {eval_loss:.4f}, Evaluation Accuracy: {eval_acc:.4f}")

In [None]:
from pathlib import Path

# ---- save model and config
save_path = Path("./model_checkpoints/assignment2-2-2")
if not save_path.exists():
    save_path.mkdir(parents=True)

save_dict = {
    "batch_size": batch_size,
    "epochs": epochs,
    "seed": seed,
    "optimizer": str(optimizer),
    "optimizer_kwargs": optimizer_kwargs,
    "test_loss": eval_loss,
    "test_accuracy": eval_acc,
    "model_state_dict": model.state_dict(),
    "model_config": model_config.to_dict(),
}

model_path = (
    save_path
    / f"{model_name}_test_loss_{eval_loss:.4f}_test_accuracy_{eval_acc:.4f}.pth"
)
torch.save(save_dict, model_path)
print(f"model saved to {model_path}")
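
To check that a saved checkpoint can be restored, the round trip below may serve as a rough sketch. It uses a tiny `nn.Linear` and a temporary file instead of the real model and path, and the accuracy value is a placeholder rather than a real result; all `demo_*` names are hypothetical. For the actual checkpoint, the model would presumably be rebuilt first from the stored `model_config` (for example with `ViTConfig.from_dict`) before calling `load_state_dict`.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

demo_model = nn.Linear(4, 2)
demo_checkpoint = {
    "test_accuracy": 0.5,  # placeholder value, not a real result
    "model_state_dict": demo_model.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "demo.pth"
    torch.save(demo_checkpoint, demo_path)
    # map_location="cpu" lets the file load on machines without a GPU
    loaded = torch.load(demo_path, map_location="cpu")

# A fresh model with the same architecture receives the saved weights.
restored_model = nn.Linear(4, 2)
restored_model.load_state_dict(loaded["model_state_dict"])
print(torch.equal(restored_model.weight, demo_model.weight))  # True
```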

---


# Please provide your analysis of each hyperparameter.


---
