<a href="https://colab.research.google.com/github/zhiqiwang59/DL/blob/main/2_vlm/vlm_tutorial_practical_1_students.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# M2LS 2025: Vision-Language Models -- Practical 1
---
- Alexandre Galashov (agalashov@google.com)
- Petra Bevandic (Petra.Bevandic@fer.hr)
<br>

**Abstract:** In this practical, we focus on the Vision Transformer (ViT), a powerful approach for encoding images.

**Disclaimer:** You will mainly be asked to complete code blocks marked **"Your code here"**. We have taken care of most of the boilerplate code for you. However, feel free to deviate from the code we prepared and implement things the way you think is right!

# A Deep Dive into the Vision Transformer (ViT) 👁️

This section is dedicated to the Vision Transformer (ViT), the model that successfully adapted the Transformer architecture for image data.

**Our learning objectives are**:

* To understand the ViT architecture, with a special focus on the image-to-patch tokenization process.

* To implement the key components of a ViT in code.

* To train our implementation on a simple image classification task to validate our understanding.

The original [ViT](https://arxiv.org/abs/2010.11929) paper is available here for further reading.

In [1]:
#@title Necessary imports
import math
import time

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms

### Vision Transformer Architecture

Let's look at the **[Vision Transformer](https://arxiv.org/abs/2010.11929) (ViT)** architecture, illustrated in the diagram. The core idea of ViT is to adapt the successful [Transformer](https://arxiv.org/abs/1706.03762) model, originally from NLP, to process images. This is achieved through a simple, effective process:

* **Image Patchification & Embedding**: The input image is first deconstructed into a sequence of fixed-size, non-overlapping patches. You can think of these patches as the visual equivalent of "words" 🖼️. Each patch is then flattened and linearly projected into a vector. This creates a sequence of "image tokens" that the Transformer can understand.

* **Transformer Encoder**: Next, this sequence of tokens is fed into a standard Transformer Encoder, which models the relationships between the patches to build a representation of the image as a whole.

* **(Optional) MLP Head**: For a downstream task like image classification, the output representation from the Transformer is passed to a final MLP (Multi-Layer Perceptron) head, which produces the final prediction (e.g., the object class). 🎯

<img src="https://drive.google.com/uc?export=view&id=1UxoNoTXy39aIL8RPnpMVbX7wzmXCKlq6">



### From Pixels to a Sequence: How ViT Processes an Image

The first challenge for a Vision Transformer (ViT) is to convert a 2D image into a 1D sequence of vectors, the format a standard Transformer expects. This is done through a clever input pipeline:

**Image Patching (Patchification) 🖼️**. The input image is first divided into a grid of fixed-size, non-overlapping patches. This operation is often called "patchification," and each patch is treated as a single token.

* **Example**: A 256x256 pixel image, divided into 16x16 pixel patches, yields a 16x16 grid of patches, for a total of 256 patches.

**Linear Embedding 📏**. Each patch is flattened into a long vector (a 16x16 RGB patch gives 16 x 16 x 3 = 768 values) and then linearly projected to a fixed size, known as the embedding dimension (e.g., 768 dimensions). We now have 256 token embeddings.

**Adding Positional Embeddings 📍**. The Transformer architecture itself has no notion of token order. To retain the spatial information of where each patch came from, a positional embedding is added to each token embedding, in the same way as for word tokens in NLP Transformers. The original ViT learns these embeddings; in this practical we instead use the fixed sinusoidal embeddings of the original Transformer (Task 2).

**Prepending the [CLS] Token ➕**. Following the convention from [BERT](https://arxiv.org/abs/1810.04805), an extra learnable embedding—the [CLS] (class) token—is prepended to the start of the sequence. For classification tasks, the final output corresponding to this token is used to represent the entire image.

**The Final Sequence ✅**. The sequence is now ready to be fed into the Transformer Encoder. It consists of the [CLS] token plus the 256 patch embeddings, for a total input sequence length of 257 tokens. In our implementation, the positional embeddings are added after the [CLS] token is prepended, so every one of the 257 tokens carries positional information.
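To make the arithmetic concrete, here is a minimal sketch that patchifies a dummy 256x256 RGB tensor using plain `reshape`/`permute` calls. This is one of several equivalent ways to do it, and the helper name `patchify` is our own, not part of the practical. It confirms 256 patches of 768 raw values each; the zero "[CLS] slot" is only a placeholder for counting tokens, since in the real model the [CLS] token lives in embedding space, after the linear projection.

```python
import torch

def patchify(images: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Hypothetical helper: (B, C, H, W) -> (B, num_patches, C * patch_size**2)."""
    b, c, h, w = images.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image size must be divisible by patch size"
    gh, gw = h // patch_size, w // patch_size  # number of patches along each axis
    # Split H and W into (grid index, position inside the patch).
    x = images.reshape(b, c, gh, patch_size, gw, patch_size)
    # Move the grid axes to the front and keep (C, p, p) together for each patch.
    x = x.permute(0, 2, 4, 1, 3, 5)  # (B, gh, gw, C, p, p)
    # Flatten every patch into one vector; patches are ordered row by row.
    return x.reshape(b, gh * gw, c * patch_size * patch_size)

images = torch.randn(2, 3, 256, 256)
patches = patchify(images, patch_size=16)
print(patches.shape)  # torch.Size([2, 256, 768])

# Placeholder [CLS] slot (zeros), used here only to count the final sequence length.
cls_slot = torch.zeros(2, 1, patches.shape[-1])
sequence = torch.cat([cls_slot, patches], dim=1)
print(sequence.shape)  # torch.Size([2, 257, 768])
```

Note that 768 here is the raw flattened patch size; it only happens to equal the example embedding dimension.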


In [None]:
#@title Load Original image
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
image = image.resize(size=(256, 256))
image

In [None]:
#@title Let's visualize patches from this image
def visualize_patches(image, patch_size):
    """
    Visualizes how an image would be divided into patches.
    This is a conceptual visualization and doesn't involve the actual
    embedding process, just the division.
    """
    img_width, img_height = image.size
    num_patches_w = img_width // patch_size
    num_patches_h = img_height // patch_size

    fig, ax = plt.subplots(1, figsize=(8, 8))
    ax.imshow(image)
    ax.set_title(f"Original Image with {patch_size}x{patch_size} Patches")

    for i in range(num_patches_w):
        for j in range(num_patches_h):
            # Draw rectangles to represent patches
            rect = plt.Rectangle((i * patch_size, j * patch_size),
                                 patch_size, patch_size,
                                 linewidth=1, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
    plt.axis('off')
    plt.show()
conceptual_patch_size = 16
visualize_patches(image, conceptual_patch_size)
print(f"The image is conceptually divided into {(image.size[0] // conceptual_patch_size) * (image.size[1] // conceptual_patch_size)} patches of {conceptual_patch_size}x{conceptual_patch_size} pixels.")
print("-" * 50)

## Exercise: Implement ViT Image Processing

Our aim is to implement the ViT's image pre-processing module. The implementation will cover:

* **Image Embedding**: Converting image pixels into a sequence of patch vectors.

* **Position Embedding**: Injecting spatial information for each patch.

* **[CLS] Token**: Adding an optional, global representation token.

* **Final Assembly**: Combining these components into the final sequence for the Transformer.

Essentially, we will implement the left part of the figure:

<img src="https://drive.google.com/uc?export=view&id=1UxoNoTXy39aIL8RPnpMVbX7wzmXCKlq6">

### **Task 1**: Implement image patchification & embedding.

You need to implement the `nn.Module` below, which takes a batch of images, splits each image into patches, and embeds each patch into a vector of size `embed_dim`.

**Tip**: You can use convolution `nn.Conv2d` module to achieve this goal.
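Why does the tip work? A `nn.Conv2d` whose `kernel_size` and `stride` both equal the patch size visits each non-overlapping patch exactly once and applies the same linear map to it, which is exactly "flatten the patch, then project it". The sketch below checks this numerically on toy sizes (chosen arbitrarily) by copying the convolution's weights into an `nn.Linear`. It is an illustration of the idea, not the module you are asked to write.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy sizes chosen arbitrarily for illustration.
batch, channels, img, patch, dim = 2, 3, 8, 4, 5
images = torch.randn(batch, channels, img, img)

# Strided convolution: one output position per non-overlapping patch.
conv = nn.Conv2d(channels, dim, kernel_size=patch, stride=patch)
conv_out = conv(images)                            # (B, dim, img/patch, img/patch)
conv_tokens = conv_out.flatten(2).transpose(1, 2)  # (B, num_patches, dim)

# The same computation as "flatten each patch, then apply a linear layer".
linear = nn.Linear(channels * patch * patch, dim)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(dim, -1))  # (dim, C*p*p)
    linear.bias.copy_(conv.bias)
g = img // patch
patches = (images.reshape(batch, channels, g, patch, g, patch)
                 .permute(0, 2, 4, 1, 3, 5)
                 .reshape(batch, g * g, channels * patch * patch))
linear_tokens = linear(patches)

print(conv_tokens.shape)  # torch.Size([2, 4, 5])
print(torch.allclose(conv_tokens, linear_tokens, atol=1e-5))  # True
```

In practice the convolution is convenient because it fuses patch extraction and projection into one call.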

In [None]:
class ImagePatchificationAndEmbedding(nn.Module):
  """Splits the image into patches and embeds them.

  Receives channels-first ("transposed") images of shape
  `<batch_size,in_channels,img_size,img_size>`. It creates patches of size
  `patch_size` x `patch_size`, giving a total of
  `num_patches=(img_size//patch_size)^2` patches. Each patch is embedded into
  `embed_dim` dimensions. The result is a tensor of shape
  `<batch_size,num_patches,embed_dim>`.
  """
  def __init__(self, img_size:int, patch_size:int, in_channels:int, embed_dim:int):
    super().__init__()
    self.img_size = img_size
    self.patch_size = patch_size
    self.num_patches = (img_size // patch_size) ** 2
    self.in_channels = in_channels
    ############################################################################
    # Your code here
    ############################################################################
    ...

  def forward(self, transposed_image: torch.Tensor):
    x = transposed_image
    batch_size = x.shape[0]
    assert x.shape == (batch_size, self.in_channels, self.img_size, self.img_size)

    ############################################################################
    # Your code here
    ############################################################################
    ...

    return x

In [None]:
def test_patch_embedding():
  batch_size = 7
  in_channels = 3
  img_size = 256
  patch_size = 16
  embed_dim = 768
  transposed_image = torch.ones((batch_size, in_channels, img_size, img_size))
  patch_embedding = ImagePatchificationAndEmbedding(
    img_size=img_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embed_dim=embed_dim)
  with torch.no_grad():
    output = patch_embedding(transposed_image)
  num_patches = (img_size // patch_size)**2
  if output.shape != (batch_size, num_patches, embed_dim):
    raise ValueError(f'The shape of `ImagePatchificationAndEmbedding` output must be `{(batch_size, num_patches, embed_dim)}`, but got `{output.shape}`.')
  print('ImagePatchificationAndEmbedding is implemented correctly.')
test_patch_embedding()

### **Task 2**: Implement position embedding

Now that you have implemented `ImagePatchificationAndEmbedding`, let's implement the position embedding.

Following the [Transformer](https://arxiv.org/abs/1706.03762) paper, the position embedding (PE) is a tensor of shape `<SEQUENCE_LENGTH, EMBEDDING_DIM>`, where:

* `SEQUENCE_LENGTH` is the length of the sequence (in our case, the number of tokens: the patches plus the optional `[CLS]` token).
* `EMBEDDING_DIM` is the embedding dimension, which must be divisible by 2.

See the formula below (here $d_{model}$ is the `EMBEDDING_DIM`):

<img src="https://drive.google.com/uc?export=view&id=1o9Sx9dJ6DTE2dozaN62eCCb6FvJg3syF">
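In text form, the sinusoidal encoding from the Transformer paper reads:

$$
PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{model}}}\right)
$$

Here $pos$ is the token index and $i$ runs over $0, \dots, d_{model}/2 - 1$. Both entries of a pair share the same frequency, so every (sin, cos) pair satisfies $\sin^2 + \cos^2 = 1$; this is the property that `test_position_embeddings` checks.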


In [None]:
def compute_position_embedding(sequence_length: int, embed_dim: int):
  assert embed_dim % 2 == 0, "`embed_dim` must be divisible by 2."
  position_embedding = torch.zeros(sequence_length, embed_dim)
  position = torch.arange(0, sequence_length, dtype=torch.float)
  position = torch.reshape(position, (sequence_length, 1))

  ##############################################################################
  # Your code here
  ##############################################################################
  sin_terms = ...
  cos_terms = ...

  # All even terms are sin
  position_embedding[:, 0::2] = sin_terms
  # All odd terms are cos
  position_embedding[:, 1::2] = cos_terms

  return position_embedding

In [None]:
def test_position_embeddings():
  sequence_length = 16
  embed_dim = 768
  position_embedding = compute_position_embedding(sequence_length, embed_dim)
  # Check the shapes
  if position_embedding.shape != (sequence_length, embed_dim):
    raise ValueError(f'`position_embedding` shape must be `{(sequence_length, embed_dim)}` but is {position_embedding.shape}')
  # Check that even & odd terms are sin and cos
  sin_terms = position_embedding[:, 0::2]
  cos_terms = position_embedding[:, 1::2]
  num_matches = torch.sum(torch.abs(sin_terms**2 + cos_terms**2 - 1.) <= 1e-5)
  if num_matches != sequence_length*(embed_dim // 2):
    raise ValueError('Position embedding sin and cos terms should have a sum of squares equal to 1.')
  print('Position embedding is implemented correctly.')
test_position_embeddings()

### **Task 3**: Final assembly -- full image embedding + `[CLS]` token
Now you are going to implement the `ImageEmbedding` module, which takes an image, splits it into patches and embeds them, optionally prepends a `[CLS]` token, and adds the position embedding.
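This task relies on two tensor operations: (i) broadcasting a single learnable `[CLS]` vector across the batch and concatenating it in front of the patch tokens, and (ii) adding a `(num_tokens, embed_dim)` position table to a `(batch, num_tokens, embed_dim)` tensor via broadcasting. Below is a minimal sketch with toy shapes; the variable names and the zero initialisation of the `[CLS]` parameter are our assumptions.

```python
import torch
import torch.nn as nn

# Toy shapes, chosen for illustration only.
batch, num_patches, embed_dim = 4, 9, 8
patch_tokens = torch.randn(batch, num_patches, embed_dim)

# One learnable [CLS] vector shared by every image (zero init is an assumption).
cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# expand() repeats it along the batch axis without copying memory.
cls_tokens = cls_token.expand(batch, -1, -1)           # (B, 1, D)
tokens = torch.cat([cls_tokens, patch_tokens], dim=1)  # (B, num_patches + 1, D)

# A (num_tokens, D) table broadcasts over the batch dimension.
position_table = torch.randn(num_patches + 1, embed_dim)
tokens = tokens + position_table                       # (B, num_patches + 1, D)

print(tokens.shape)  # torch.Size([4, 10, 8])
```

Because the `[CLS]` token is prepended before the position table is added, position 0 of the table belongs to `[CLS]`.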

In [None]:
class ImageEmbedding(nn.Module):
  """Embeds an image into token embeddings for the Vision Transformer.

  Applies patch embedding, optionally prepends a class ([CLS]) embedding, and
  adds the position embedding."""
  def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int, add_cls_token: bool = True):
    super().__init__()
    # Save arguments
    self.img_size = img_size
    self.patch_size = patch_size
    self.in_channels = in_channels
    self.embed_dim = embed_dim
    self.add_cls_token = add_cls_token

    # Patch embedding
    self.patch_embedding = ImagePatchificationAndEmbedding(img_size, patch_size, in_channels, embed_dim)
    self.num_patches = self.patch_embedding.num_patches

    # Compute num tokens
    if add_cls_token:
      self.num_tokens = self.num_patches + 1
      ##########################################################################
      # Your code here
      ##########################################################################
      ...

    else:
      self.num_tokens = self.num_patches

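    # Position embedding (fixed, not learned). The buffer `positional_embedding`
    # moves with the module on `.to(device)`, so prefer it in `forward`.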
    self.position_embeding = compute_position_embedding(self.num_tokens, embed_dim)
    self.register_buffer('positional_embedding', self.position_embeding)

  def forward(self, transposed_image: torch.Tensor) -> torch.Tensor:
    batch_size = transposed_image.shape[0]
    # Compute patch embedding
    ############################################################################
    # Your code here
    ############################################################################
    tokens_emb = ...
    assert tokens_emb.shape == (batch_size, self.num_patches, self.embed_dim)
    if self.add_cls_token:
      # Add class embedding
      ##########################################################################
      # Your code here
      ##########################################################################
      ...

    assert tokens_emb.shape == (batch_size, self.num_tokens, self.embed_dim)

    # Add position embedding
    ############################################################################
    # Your code here
    ############################################################################
    ...

    return tokens_emb

In [None]:
def test_image_embedding():
  batch_size = 7
  in_channels = 3
  img_size = 256
  patch_size = 16
  embed_dim = 768
  transposed_image = torch.ones((batch_size, in_channels, img_size, img_size))
  # Not using class embedding
  patch_embedding = ImageEmbedding(
    img_size=img_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embed_dim=embed_dim,
    add_cls_token=False)
  with torch.no_grad():
    output = patch_embedding(transposed_image)
  num_patches = (img_size // patch_size)**2
  if output.shape != (batch_size, num_patches, embed_dim):
    raise ValueError(f'The shape of `ImageEmbedding` output must be `{(batch_size, num_patches, embed_dim)}`, but got `{output.shape}`.')

  # Using [CLS] token
  patch_embedding = ImageEmbedding(
    img_size=img_size,
    patch_size=patch_size,
    in_channels=in_channels,
    embed_dim=embed_dim,
    add_cls_token=True)
  with torch.no_grad():
    output = patch_embedding(transposed_image)
  num_patches = (img_size // patch_size)**2
  if output.shape != (batch_size, num_patches + 1, embed_dim):
    raise ValueError(f'The shape of `ImageEmbedding` output must be `{(batch_size, num_patches + 1, embed_dim)}`, but got `{output.shape}`.')

  print('ImageEmbedding is implemented correctly.')
test_image_embedding()

Good job! You have implemented all the image-embedding components correctly! Now let's see whether they work together in a ViT.

### **Task 4**: Run the following code -- building blocks for ViT

We have already implemented the rest of the ViT for you. Below are its building blocks.

Just run the code!

##### **MLP block**

In [None]:
class MLP(nn.Module):
  """Multi-Layer Perceptron for the Transformer block."""
  def __init__(self, embed_dim: int, mlp_dim: int, dropout_rate: float = 0.0):
    super().__init__()
    self.fc1 = nn.Linear(embed_dim, mlp_dim)
    self.gelu = nn.GELU()
    self.dropout1 = nn.Dropout(dropout_rate)
    self.fc2 = nn.Linear(mlp_dim, embed_dim)
    self.dropout2 = nn.Dropout(dropout_rate)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.fc1(x)
    x = self.gelu(x)
    x = self.dropout1(x)
    x = self.fc2(x)
    x = self.dropout2(x)
    return x

##### **Multi-head self-attention module**

In [None]:
# --- Small Vision Transformer (ViT) Model Definition ---

class MultiHeadSelfAttention(nn.Module):
  """Multi-Head Self-Attention mechanism."""
  def __init__(self, embed_dim: int, num_heads: int, dropout_rate: float = 0.0):
    super().__init__()
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads
    assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"

    self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)
    self.out_proj = nn.Linear(embed_dim, embed_dim)
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    batch_size, num_tokens, embed_dim = x.shape

    # Generate Q, K, V matrices
    qkv = self.qkv_proj(x).reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
    qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, num_heads, num_tokens, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Scaled Dot-Product Attention
    # (batch_size, num_heads, num_tokens, head_dim) @ (batch_size, num_heads, head_dim, num_tokens)
    attention_scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
    attention_weights = torch.softmax(attention_scores, dim=-1)
    attention_weights = self.dropout(attention_weights)

    # (batch_size, num_heads, num_tokens, head_dim) -> (batch_size, num_tokens, embed_dim)
    output = (attention_weights @ v).transpose(1, 2).reshape(batch_size, num_tokens, embed_dim)

    output = self.out_proj(output)
    return output
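One way to sanity-check hand-written attention like the one above is to compare it with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`. This assumes PyTorch 2.0 or later. The sketch repeats the same reshape/permute steps on random inputs, without dropout; the sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Arbitrary toy sizes; assumes PyTorch >= 2.0 for scaled_dot_product_attention.
batch, tokens, num_heads, head_dim = 2, 5, 4, 8
embed_dim = num_heads * head_dim
qkv_proj = torch.nn.Linear(embed_dim, 3 * embed_dim)
x = torch.randn(batch, tokens, embed_dim)

with torch.no_grad():
    # Same layout as MultiHeadSelfAttention: (3, B, heads, T, head_dim).
    qkv = qkv_proj(x).reshape(batch, tokens, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]

    # Manual scaled dot-product attention.
    scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
    manual = torch.softmax(scores, dim=-1) @ v

    # Fused reference (no mask, no dropout, default scale 1/sqrt(head_dim)).
    reference = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, reference, atol=1e-5))  # True
```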

##### **Transformer encoder block**

In [None]:
class TransformerEncoderBlock(nn.Module):
  """Single Transformer Encoder Block."""
  def __init__(self, embed_dim:int, num_heads:int, mlp_dim:int, dropout_rate: float = 0.0):
    super().__init__()
    self.norm1 = nn.LayerNorm(embed_dim)
    self.attn = MultiHeadSelfAttention(embed_dim, num_heads, dropout_rate)
    self.norm2 = nn.LayerNorm(embed_dim)
    self.mlp = MLP(embed_dim, mlp_dim, dropout_rate)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = x + self.attn(self.norm1(x))
    x = x + self.mlp(self.norm2(x))
    return x
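`TransformerEncoderBlock` uses the *pre-norm* layout: LayerNorm is applied before attention and before the MLP, inside each residual branch, as in the ViT paper. PyTorch's built-in `nn.TransformerEncoderLayer` can be configured similarly with `norm_first=True`, `activation="gelu"` and `batch_first=True`. The sketch below, with sizes matching the MNIST model used later, checks that such a layer preserves the input shape and has the same number of parameters as our block (counted by hand from its layers). Details such as dropout placement may differ, so treat it as a reference point rather than a drop-in replacement.

```python
import torch
import torch.nn as nn

embed_dim, num_heads, mlp_dim = 192, 6, 384

def block_param_count(d: int, m: int) -> int:
    """Parameters of TransformerEncoderBlock, counted layer by layer."""
    layer_norms = 2 * (2 * d)        # norm1 and norm2: weight + bias each
    qkv_proj = d * 3 * d + 3 * d     # Linear(d, 3d)
    out_proj = d * d + d             # Linear(d, d)
    mlp = (d * m + m) + (m * d + d)  # Linear(d, m) and Linear(m, d)
    return layer_norms + qkv_proj + out_proj + mlp

layer = nn.TransformerEncoderLayer(
    d_model=embed_dim, nhead=num_heads, dim_feedforward=mlp_dim,
    dropout=0.1, activation="gelu", batch_first=True, norm_first=True)
layer.eval()  # disable dropout for a deterministic forward pass

x = torch.randn(2, 50, embed_dim)  # (batch, tokens, embed_dim)
with torch.no_grad():
    y = layer(x)

n_builtin = sum(p.numel() for p in layer.parameters())
print(y.shape)  # torch.Size([2, 50, 192])
print(n_builtin == block_param_count(embed_dim, mlp_dim))  # True
```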

### **Task 5**: Add `ImageEmbedding` to the Vision Transformer (ViT) below

Now we put everything together in the `VisionTransformer` module. Your objective is to create an `ImageEmbedding` module in `__init__` and use it in `forward` to embed the input images.

In [None]:
class VisionTransformer(nn.Module):
  """Small Vision Transformer."""
  def __init__(self,
               img_size:int = 32,
               patch_size: int = 4,
               in_channels: int = 3,
               num_classes: int = 10,
               embed_dim: int = 256,
               num_heads: int = 8,
               num_layers: int = 6,
               mlp_dim: int = 512,
               dropout_rate: float = 0.1,
               add_cls_token: bool = True):
      super().__init__()

      # Image embedding.
      ##########################################################################
      # Your code here
      ##########################################################################
      self.image_embedding = ...



      self.dropout = nn.Dropout(dropout_rate)

      self.transformer_encoder = nn.Sequential(
          *[TransformerEncoderBlock(embed_dim, num_heads, mlp_dim, dropout_rate) for _ in range(num_layers)]
      )

      self.norm = nn.LayerNorm(embed_dim)
      self.head = nn.Linear(embed_dim, num_classes)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    ############################################################################
    # Your code here
    ############################################################################
    tokens_emb = ...

    out = self.dropout(tokens_emb)

    # Transformer encoder
    out = self.transformer_encoder(out)

    # Take the output at position 0 for classification (the [CLS] token when add_cls_token=True)
    out = self.norm(out[:, 0])
    logits = self.head(out)
    return logits
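`forward` reads the first output token (`out[:, 0]`), which is the `[CLS]` token only when `add_cls_token=True`; with `add_cls_token=False` it would silently classify from the first patch instead. A common alternative in that case is global average pooling over all tokens. A minimal sketch of such a pooling step (the helper name `pool_tokens` is hypothetical, not part of the practical):

```python
import torch

def pool_tokens(tokens: torch.Tensor, has_cls_token: bool) -> torch.Tensor:
    """Hypothetical helper: (B, T, D) encoder outputs -> (B, D) image representation."""
    if has_cls_token:
        return tokens[:, 0]      # output at the [CLS] position
    return tokens.mean(dim=1)    # global average over all patch tokens

encoded = torch.randn(3, 50, 192)
print(pool_tokens(encoded, has_cls_token=True).shape)   # torch.Size([3, 192])
print(pool_tokens(encoded, has_cls_token=False).shape)  # torch.Size([3, 192])
```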

In [None]:
def test_vit():
  batch_size = 13
  in_channels = 3
  img_size = 32
  patch_size = 4
  num_classes = 17
  transposed_img = torch.ones((batch_size, in_channels, img_size, img_size))
  vit = VisionTransformer(in_channels=in_channels, img_size=img_size, patch_size=patch_size, num_classes=num_classes)
  with torch.no_grad():
    out = vit(transposed_img)
  if out.shape != (batch_size, num_classes):
    raise ValueError(f'The shape of `VisionTransformer` output must be `{(batch_size, num_classes)}`, but got `{out.shape}`.')
  print('VisionTransformer is implemented correctly.')
test_vit()


### **Task 6**: Run training of ViT on a toy dataset

Now that we have implemented ViT, we are going to train an image classifier on a toy dataset: the first 10,000 MNIST training digits, with each grayscale image replicated to 3 channels.

There is nothing to implement: just run the code below and see what accuracy you get!
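Before training, it helps to know what the model will actually see. With the settings used below (28x28 MNIST digits copied to 3 channels, `patch_size=4`), the token arithmetic works out as in this small sketch; `vit_token_stats` is a hypothetical helper, not part of the practical.

```python
def vit_token_stats(img_size: int, patch_size: int, in_channels: int, add_cls_token: bool = True) -> dict:
    """Hypothetical helper: token counts for a square image split into square patches."""
    assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
    grid = img_size // patch_size
    num_patches = grid * grid
    return {
        "grid": grid,                                    # patches per side
        "num_patches": num_patches,
        "patch_dim": in_channels * patch_size ** 2,      # values per flattened patch
        "num_tokens": num_patches + int(add_cls_token),  # sequence length fed to the encoder
    }

print(vit_token_stats(img_size=28, patch_size=4, in_channels=3))
# {'grid': 7, 'num_patches': 49, 'patch_dim': 48, 'num_tokens': 50}
print(vit_token_stats(img_size=256, patch_size=16, in_channels=3))
# {'grid': 16, 'num_patches': 256, 'patch_dim': 768, 'num_tokens': 257}
```

So each 48-value MNIST patch is projected to `embed_dim=192`, and the encoder processes sequences of 50 tokens.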

In [None]:
from torch.utils.data import Subset

In [None]:
class MNISTColorDataset(Dataset):
  def __init__(self, mnist_split):
    self.mnist_split = mnist_split

  def __len__(self):
    return len(self.mnist_split)

  def __getitem__(self, idx):
    mnist_image, label = self.mnist_split[idx]
    # Replicate the single grayscale channel: (1, 28, 28) -> (3, 28, 28).
    mnist_image = mnist_image.repeat(3, 1, 1)
    # Keep the label as a shape-(1,) tensor; the training loop reads batch['label'][:, 0].
    label = torch.tensor([label], dtype=torch.long)
    return {"input": mnist_image, "label": label}

In [None]:
subset_size = 10_000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
subset_indices = range(subset_size)
mnist_trainset = Subset(mnist_trainset, subset_indices)

mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

mnist_train = MNISTColorDataset(mnist_trainset)
mnist_test = MNISTColorDataset(mnist_testset)

# Define DataLoaders
batch_size = 128
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Initialize the model
model = VisionTransformer(
    img_size=28,
    patch_size=4,
    in_channels=3,
    num_classes=10,
    embed_dim=192,  # Smaller embedding dimension for a "small" ViT
    num_heads=6,    # Fewer attention heads
    num_layers=6,   # Fewer transformer layers
    mlp_dim=384,    # Smaller MLP dimension
    dropout_rate=0.1
).to(device)
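# `position_embeding` is a plain tensor attribute rather than a registered buffer,
# so `.to(device)` above does not move it; move it to the device explicitly.
model.image_embedding.position_embeding = model.image_embedding.position_embeding.to(device)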

# --- 3. Optimizer and Loss Function ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
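# Cosine annealing over T_max=200 scheduler steps (one step per epoch). With only
# num_epochs = 5, the learning rate stays close to its initial value.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)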

# --- 4. Training and Evaluation Loop ---
num_epochs = 5 # You might need more epochs for better performance (e.g., 100-200)
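The cell above uses `T_max=200` while `num_epochs = 5`. Since `scheduler.step()` is called once per epoch, `CosineAnnealingLR` follows $\eta_t = \eta_{min} + \tfrac{1}{2}(\eta_{max} - \eta_{min})(1 + \cos(\pi t / T_{max}))$, so after 5 of 200 steps the learning rate has barely moved. The sketch below compares this with `T_max=5`, using a dummy parameter in place of the model; which schedule trains better here is not something we have measured.

```python
import torch

def lr_after(num_steps: int, t_max: int, base_lr: float = 1e-3) -> float:
    """Learning rate after `num_steps` scheduler steps of CosineAnnealingLR."""
    param = torch.nn.Parameter(torch.zeros(1))  # dummy stand-in for model weights
    optimizer = torch.optim.AdamW([param], lr=base_lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    for _ in range(num_steps):
        optimizer.step()  # no gradients here, so the parameter is left untouched
        scheduler.step()
    return optimizer.param_groups[0]["lr"]

print(f"T_max=200: {lr_after(5, t_max=200):.2e}")  # close to the initial 1e-3
print(f"T_max=5:   {lr_after(5, t_max=5):.2e}")    # annealed to (about) 0
```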

In [None]:
train_losses = []
test_losses = []
test_accuracies = []

print("\nStarting training...")
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    start_time = time.time()

    for i, batch in enumerate(train_loader):
        inputs, labels = batch['input'].to(device), batch['label'][:,0].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_train_loss = running_loss / len(train_loader)
    train_losses.append(epoch_train_loss)
    scheduler.step() # Update learning rate

    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            inputs, labels = batch['input'].to(device), batch['label'][:,0].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_test_loss = test_loss / len(test_loader)
    test_losses.append(epoch_test_loss)
    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)

    end_time = time.time()
    epoch_duration = end_time - start_time

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {epoch_train_loss:.4f}, "
          f"Test Loss: {epoch_test_loss:.4f}, "
          f"Test Accuracy: {accuracy:.2f}%, "
          f"Time: {epoch_duration:.2f}s")

print("\nTraining finished!")

In [None]:
# --- 5. Plotting Results ---
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
plt.title('Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(test_accuracies, label='Test Accuracy', color='green')
plt.title('Accuracy over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

# --- 6. Final Evaluation on Test Set ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        inputs, labels = batch['input'].to(device), batch['label'][:,0].to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

final_accuracy = 100 * correct / total
print(f"\nFinal Accuracy on the 10,000 test images: {final_accuracy:.2f}%")

If your final test accuracy is around 90\%, everything is most likely working correctly!