In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

from PIL import Image
import numpy as np
import pandas as pd
import os
import glob

In [5]:
class ViTDataset(Dataset):
    """Image dataset that yields each sample as a stack of flattened patches.

    Every image is loaded from disk, converted to RGB, resized to a square of
    ``image_size`` pixels and cut into non-overlapping ``patch_size`` x
    ``patch_size`` tiles. Each tile is flattened into one vector.
    """

    def __init__(self, image_paths, labels, image_size=224, patch_size=16, transform=None):
        """
        Args:
            image_paths (list of str): List of file paths to the images.
            labels (list of int): List of labels corresponding to each image.
            image_size (int): Side length the images are resized to (e.g., 224).
            patch_size (int): Side length of each square patch (e.g., 16).
            transform (callable, optional): Optional transform applied to the
                patch tensor of each sample (it runs after patch extraction,
                not on the raw image).

        Raises:
            ValueError: If ``image_size`` or ``patch_size`` is not positive, or
                if ``patch_size`` does not evenly divide ``image_size``.
        """
        # A patch size that does not tile the image exactly would produce
        # ragged patches that cannot be stacked, so reject it up front with a
        # clear message instead of failing later inside __getitem__.
        if image_size <= 0 or patch_size <= 0 or image_size % patch_size != 0:
            raise ValueError(
                f"patch_size ({patch_size}) must be positive and evenly divide "
                f"image_size ({image_size})"
            )

        self.image_paths = image_paths
        self.labels = labels
        self.image_size = image_size
        self.patch_size = patch_size
        self.transform = transform

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.image_paths)

    def _split_into_patches(self, image):
        """
        Cut a channel-first image tensor into flattened, non-overlapping patches.

        Args:
            image (Tensor): Tensor of shape (channels, height, width) whose
                height and width are multiples of ``self.patch_size``.

        Returns:
            Tensor: Shape (num_patches, channels * patch_size * patch_size).
            Patches are ordered row by row (top-left first), and each patch is
            flattened in (channel, row, column) order.
        """
        size = self.patch_size
        channels, height, width = image.shape
        rows, cols = height // size, width // size

        # (C, H, W) -> (C, rows, size, cols, size): split both spatial axes
        # into a "which patch" index and a "pixel inside the patch" index.
        tiles = image.reshape(channels, rows, size, cols, size)
        # Bring the patch indices to the front: (rows, cols, C, size, size).
        tiles = tiles.permute(1, 3, 0, 2, 4)
        # Merge patch indices into one axis and flatten each patch.
        return tiles.reshape(rows * cols, channels * size * size)

    def _image_to_patches(self, image):
        """
        Convert image into patches.

        Args:
            image: The input image; expected to be a PIL image, since it is
                resized with ``image.resize`` and converted with ``ToTensor``.

        Returns:
            patches (Tensor): A tensor of patches shaped (num_patches, patch_dim).
        """
        # Resize image to the required input size for ViT (224x224 by default)
        image = image.resize((self.image_size, self.image_size))

        # Convert to a (channels, height, width) tensor
        image = transforms.ToTensor()(image)

        # Divide the image into flattened patches in a single vectorised step
        return self._split_into_patches(image)

    def __getitem__(self, idx):
        """Load the image at ``idx`` and return ``(patches, label)``."""
        # Open inside a context manager so the underlying file handle is closed
        # deterministically; convert() returns an independent in-memory image.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert("RGB")

        # Convert image to patches
        patches = self._image_to_patches(image)

        # Get the label for this image
        label = self.labels[idx]

        # If a transform is specified, apply it to the patch tensor
        if self.transform:
            patches = self.transform(patches)

        return patches, label
        

In [6]:
class PatchPositionalEmbeddingWithCLS(nn.Module):
    """Project flattened patches to tokens, add positions, and prepend [CLS].

    The positional embeddings cover the patch tokens only; the [CLS] token is
    concatenated afterwards and carries no positional term.
    """

    def __init__(self, image_size=224, patch_size=16, num_channels=3, embedding_dim=768):
        """
        Args:
            image_size (int): Side length of the (square) input image, e.g. 224.
            patch_size (int): Side length of each (square) patch, e.g. 16.
            num_channels (int): Number of input channels (3 for RGB).
            embedding_dim (int): Width of every token produced by this layer.
        """
        super().__init__()

        # Keep the configuration around for inspection.
        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embedding_dim = embedding_dim

        # Derived sizes: patches per image and length of one flattened patch.
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.patch_dim = num_channels * patch_size * patch_size

        # Linear map from a flattened patch to the token space.
        self.patch_projection = nn.Linear(self.patch_dim, embedding_dim)

        # One learnable position vector per patch: [1, num_patches, embedding_dim].
        position_init = torch.randn(1, self.num_patches, embedding_dim)
        self.positional_embeddings = nn.Parameter(position_init)

        # Single learnable [CLS] vector, shared by every sample: [1, 1, embedding_dim].
        cls_init = torch.randn(1, 1, embedding_dim)
        self.cls_token = nn.Parameter(cls_init)

    def forward(self, x):
        """
        Embed a batch of flattened patches.

        Args:
            x (Tensor): Flattened patches shaped [batch, num_patches, patch_dim].

        Returns:
            Tensor: Tokens shaped [batch, num_patches + 1, embedding_dim], with
            the [CLS] token at sequence index 0.
        """
        batch_size = x.size(0)

        # Project each patch and add its position vector (broadcast over batch).
        tokens = self.patch_projection(x) + self.positional_embeddings

        # Replicate the [CLS] vector for every sample and put it in front.
        cls_per_sample = self.cls_token.expand(batch_size, -1, -1)
        return torch.cat([cls_per_sample, tokens], dim=1)


In [8]:
class EncoderBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block for batch-first token sequences.

    Computes ``x + Dropout(MSA(LN(x)))`` followed by
    ``x + Dropout(FFN(LN(x)))``.
    """

    def __init__(self, embedding_dim=768, num_heads=12, ff_hidden_dim=3072, dropout=0.1):
        """
        Args:
            embedding_dim (int): The dimension of the patch embeddings (e.g., 768).
            num_heads (int): The number of attention heads for multi-head attention.
            ff_hidden_dim (int): The hidden dimension of the feedforward network (usually larger than embedding_dim).
            dropout (float): Dropout rate for regularization.
        """
        super(EncoderBlock, self).__init__()

        # Multi-Head Self Attention (MSA).
        # batch_first=True is required: this block receives tensors shaped
        # [batch, seq_len, embedding_dim]. With the default (batch_first=False)
        # nn.MultiheadAttention would treat the batch axis as the sequence and
        # attend across different samples instead of across patches.
        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Feedforward Network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, ff_hidden_dim),  # First linear layer
            nn.GELU(),                               # Activation
            nn.Dropout(dropout),                     # Dropout
            nn.Linear(ff_hidden_dim, embedding_dim)  # Second linear layer
        )

        # LayerNorm layers (applied before each sub-layer: pre-LayerNorm)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Dropout applied to each sub-layer output before the residual add
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass through the encoder block:
        1. Apply LayerNorm, then Multi-Head Self Attention (MSA).
        2. Add the (dropped-out) attention output back onto the input.
        3. Apply LayerNorm, then the Feedforward Network (FFN).
        4. Add the (dropped-out) FFN output back onto the result of step 2.

        Args:
            x (Tensor): Tokens shaped [batch, seq_len, embedding_dim].

        Returns:
            Tensor: Same shape as ``x``.
        """
        # Step 1: LayerNorm, then self-attention over the tokens of each sample.
        # The attention weights are discarded, so skip computing them.
        x_norm = self.norm1(x)
        attn_output, _ = self.attention(x_norm, x_norm, x_norm, need_weights=False)

        # Step 2: Residual connection around the attention sub-layer
        x = x + self.dropout(attn_output)

        # Step 3: LayerNorm, then the position-wise feedforward network
        ffn_output = self.ffn(self.norm2(x))

        # Step 4: Residual connection around the feedforward sub-layer
        x = x + self.dropout(ffn_output)
        return x


In [9]:
class MLPHead(nn.Module):
    """Two-layer classification head: Linear -> GELU -> Linear."""

    def __init__(self, embedding_dim=768, num_classes=1000, hidden_dim=2048):
        """
        Args:
            embedding_dim (int): The dimension of the incoming features (e.g., 768).
            num_classes (int): The number of classes for classification.
            hidden_dim (int): Width of the hidden layer. Defaults to 2048, the
                value that was previously hard-coded.
        """
        super(MLPHead, self).__init__()

        # First fully connected layer (from embedding dimension to hidden_dim)
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)

        # Second fully connected layer (from hidden_dim to the number of classes)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        # GELU activation function after the first fully connected layer
        self.gelu = nn.GELU()

    def forward(self, cls_token_output):
        """
        Forward pass through the MLP head:
        1. Apply the first fully connected layer and GELU activation.
        2. Apply the second fully connected layer to produce the final logits.

        Args:
            cls_token_output (Tensor): Features whose last dimension is
                ``embedding_dim``.

        Returns:
            Tensor: Logits whose last dimension is ``num_classes`` (no softmax
            is applied).
        """
        # Step 1: First fully connected layer followed by GELU activation
        hidden = self.gelu(self.fc1(cls_token_output))

        # Step 2: Output layer producing the raw class scores
        return self.fc2(hidden)


In [10]:
class ViTModel(nn.Module):
    """Minimal Vision Transformer classifier with a single encoder block.

    Pipeline: patch projection + positional embeddings + [CLS] token ->
    one encoder block -> MLP head applied to the [CLS] token output.
    """

    def __init__(self, image_size=224, patch_size=16, embedding_dim=768, num_heads=12, ff_hidden_dim=3072, num_classes=1000, dropout=0.1, num_channels=3):
        """
        Args:
            image_size (int): The size of the input image (e.g., 224x224).
            patch_size (int): The size of each patch (e.g., 16x16).
            embedding_dim (int): The dimension of the patch embeddings (e.g., 768).
            num_heads (int): The number of attention heads for multi-head attention.
            ff_hidden_dim (int): The hidden dimension of the feedforward network.
            num_classes (int): The number of output classes for classification.
            dropout (float): Dropout rate for regularization.
            num_channels (int): Number of image channels the patches were cut
                from (3 for RGB). Forwarded to the patch embedding layer, which
                uses it to size the flattened patch dimension.
        """
        super(ViTModel, self).__init__()

        # Patch + Positional Embedding with [CLS] token
        self.patch_positional_embedding = PatchPositionalEmbeddingWithCLS(
            image_size=image_size,
            patch_size=patch_size,
            num_channels=num_channels,
            embedding_dim=embedding_dim
        )

        # Encoder Block: This processes the image patches including the [CLS] token
        self.encoder_block = EncoderBlock(
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            ff_hidden_dim=ff_hidden_dim,
            dropout=dropout
        )

        # MLP Head: Classification head that uses the [CLS] token's output
        self.mlp_head = MLPHead(embedding_dim=embedding_dim, num_classes=num_classes)

    def forward(self, x):
        """
        Forward pass through the model:
        1. Embed the flattened patches and prepend the [CLS] token.
        2. Pass the token sequence through the encoder block.
        3. Extract the output of the [CLS] token (first token in the sequence).
        4. Pass the [CLS] token's output to the MLP head for classification.

        Args:
            x (Tensor): Flattened image patches shaped
                [batch_size, num_patches, patch_size * patch_size * num_channels].
                This is NOT a raw [batch, channels, height, width] image: the
                embedding layer applies a linear projection over the last
                (flattened patch) dimension.

        Returns:
            Tensor: Logits shaped [batch_size, num_classes].
        """
        # Step 1: Project the patches, add positional embeddings, prepend [CLS]
        x = self.patch_positional_embedding(x)  # Output shape: [batch_size, num_patches+1, embedding_dim]

        # Step 2: Pass the tokens through the encoder block (self-attention and feedforward)
        encoder_output = self.encoder_block(x)  # Output shape: [batch_size, num_patches+1, embedding_dim]

        # Step 3: Extract the [CLS] token's output (first token in the sequence)
        cls_token_output = encoder_output[:, 0, :]  # Shape: [batch_size, embedding_dim]

        # Step 4: Pass the [CLS] token's output through the MLP head for classification
        logits = self.mlp_head(cls_token_output)  # Output shape: [batch_size, num_classes]

        return logits
